agno 2.2.6__py3-none-any.whl → 2.2.7__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
agno/api/os.py CHANGED
@@ -14,4 +14,4 @@ def log_os_telemetry(launch: OSLaunch) -> None:
         )
         response.raise_for_status()
     except Exception as e:
-        log_debug(f"Could not create OS launch: {e}")
+        log_debug(f"Could not register OS launch for telemetry: {e}")
agno/culture/manager.py CHANGED
@@ -134,9 +134,10 @@ class CultureManager:
         if not self.db:
             return None
 
-        self.db = cast(AsyncBaseDb, self.db)
-
-        return await self.db.get_all_cultural_knowledge(name=name)
+        if isinstance(self.db, AsyncBaseDb):
+            return await self.db.get_all_cultural_knowledge(name=name)
+        else:
+            return self.db.get_all_cultural_knowledge(name=name)
 
     def add_cultural_knowledge(
         self,
@@ -230,7 +231,11 @@ class CultureManager:
         if not messages or not isinstance(messages, list):
             raise ValueError("Invalid messages list")
 
-        knowledge = self.get_all_knowledge()
+        if isinstance(self.db, AsyncBaseDb):
+            knowledge = await self.aget_all_knowledge()
+        else:
+            knowledge = self.get_all_knowledge()
+
         if knowledge is None:
             knowledge = []
 
@@ -32,6 +32,7 @@ class PromptInjectionGuardrail(BaseGuardrail):
         "ignore safeguards",
         "admin override",
         "root access",
+        "forget everything",
     ]
 
     def check(self, run_input: Union[RunInput, TeamRunInput]) -> None:
@@ -0,0 +1,262 @@
+import asyncio
+from dataclasses import dataclass
+from os import getenv
+from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple
+
+from agno.knowledge.embedder.base import Embedder
+from agno.utils.log import logger
+
+try:
+    from vllm import LLM  # type: ignore
+    from vllm.outputs import EmbeddingRequestOutput  # type: ignore
+except ImportError:
+    raise ImportError("`vllm` not installed. Please install using `pip install vllm`.")
+
+if TYPE_CHECKING:
+    from openai import AsyncOpenAI
+    from openai import OpenAI as OpenAIClient
+    from openai.types.create_embedding_response import CreateEmbeddingResponse
+
+
+@dataclass
+class VLLMEmbedder(Embedder):
+    """
+    VLLM Embedder supporting both local and remote deployment modes.
+
+    Local Mode (default):
+    - Loads model locally and runs inference on your GPU/CPU
+    - No API key required
+    - Example: VLLMEmbedder(id="intfloat/e5-mistral-7b-instruct")
+
+    Remote Mode:
+    - Connects to a remote vLLM server via OpenAI-compatible API
+    - Uses OpenAI SDK to communicate with vLLM's OpenAI-compatible endpoint
+    - Requires base_url and optionally api_key
+    - Example: VLLMEmbedder(base_url="http://localhost:8000/v1", api_key="your-key")
+    - Ref: https://docs.vllm.ai/en/latest/serving/openai_compatible_server.html
+    """
+
+    id: str = "sentence-transformers/all-MiniLM-L6-v2"
+    dimensions: int = 4096
+    # Local mode parameters
+    enforce_eager: bool = True
+    vllm_kwargs: Optional[Dict[str, Any]] = None
+    vllm_client: Optional[LLM] = None
+    # Remote mode parameters
+    api_key: Optional[str] = getenv("VLLM_API_KEY")
+    base_url: Optional[str] = None
+    request_params: Optional[Dict[str, Any]] = None
+    client_params: Optional[Dict[str, Any]] = None
+    remote_client: Optional["OpenAIClient"] = None  # OpenAI-compatible client for vLLM server
+    async_remote_client: Optional["AsyncOpenAI"] = None  # Async OpenAI-compatible client for vLLM server
+
+    @property
+    def is_remote(self) -> bool:
+        """Determine if we should use remote mode."""
+        return self.base_url is not None
+
+    def _get_vllm_client(self) -> LLM:
+        """Get local VLLM client."""
+        if self.vllm_client:
+            return self.vllm_client
+
+        _vllm_params: Dict[str, Any] = {
+            "model": self.id,
+            "task": "embed",
+            "enforce_eager": self.enforce_eager,
+        }
+        if self.vllm_kwargs:
+            _vllm_params.update(self.vllm_kwargs)
+        self.vllm_client = LLM(**_vllm_params)
+        return self.vllm_client
+
+    def _get_remote_client(self) -> "OpenAIClient":
+        """Get OpenAI-compatible client for remote vLLM server."""
+        if self.remote_client:
+            return self.remote_client
+
+        try:
+            from openai import OpenAI as OpenAIClient
+        except ImportError:
+            raise ImportError("`openai` package required for remote vLLM mode. ")
+
+        _client_params: Dict[str, Any] = {
+            "api_key": self.api_key or "EMPTY",  # VLLM can run without API key
+            "base_url": self.base_url,
+        }
+        if self.client_params:
+            _client_params.update(self.client_params)
+        self.remote_client = OpenAIClient(**_client_params)
+        return self.remote_client
+
+    def _get_async_remote_client(self) -> "AsyncOpenAI":
+        """Get async OpenAI-compatible client for remote vLLM server."""
+        if self.async_remote_client:
+            return self.async_remote_client
+
+        try:
+            from openai import AsyncOpenAI
+        except ImportError:
+            raise ImportError("`openai` package required for remote vLLM mode. ")
+
+        _client_params: Dict[str, Any] = {
+            "api_key": self.api_key or "EMPTY",
+            "base_url": self.base_url,
+        }
+        if self.client_params:
+            _client_params.update(self.client_params)
+        self.async_remote_client = AsyncOpenAI(**_client_params)
+        return self.async_remote_client
+
+    def _create_embedding_local(self, text: str) -> Optional[EmbeddingRequestOutput]:
+        """Create embedding using local VLLM."""
+        try:
+            outputs = self._get_vllm_client().embed([text])
+            return outputs[0] if outputs else None
+        except Exception as e:
+            logger.warning(f"Error creating local embedding: {e}")
+            return None
+
+    def _create_embedding_remote(self, text: str) -> "CreateEmbeddingResponse":
+        """Create embedding using remote vLLM server."""
+        _request_params: Dict[str, Any] = {
+            "input": text,
+            "model": self.id,
+        }
+        if self.request_params:
+            _request_params.update(self.request_params)
+        return self._get_remote_client().embeddings.create(**_request_params)
+
+    def get_embedding(self, text: str) -> List[float]:
+        try:
+            if self.is_remote:
+                # Remote mode: OpenAI-compatible API
+                response: "CreateEmbeddingResponse" = self._create_embedding_remote(text=text)
+                return response.data[0].embedding
+            else:
+                # Local mode: Direct VLLM
+                output = self._create_embedding_local(text=text)
+                if output and hasattr(output, "outputs") and hasattr(output.outputs, "embedding"):
+                    embedding = output.outputs.embedding
+                    if len(embedding) != self.dimensions:
+                        logger.warning(f"Expected embedding dimension {self.dimensions}, but got {len(embedding)}")
+                    return embedding
+                return []
+        except Exception as e:
+            logger.warning(f"Error extracting embedding: {e}")
+            return []
+
+    def get_embedding_and_usage(self, text: str) -> Tuple[List[float], Optional[Dict]]:
+        if self.is_remote:
+            try:
+                response: "CreateEmbeddingResponse" = self._create_embedding_remote(text=text)
+                embedding = response.data[0].embedding
+                usage = response.usage
+                if usage:
+                    return embedding, usage.model_dump()
+                return embedding, None
+            except Exception as e:
+                logger.warning(f"Error in remote embedding: {e}")
+                return [], None
+        else:
+            embedding = self.get_embedding(text=text)
+            # Local VLLM doesn't provide usage information
+            return embedding, None
+
+    async def async_get_embedding(self, text: str) -> List[float]:
+        """Async version of get_embedding using thread executor for local mode."""
+        if self.is_remote:
+            # Remote mode: async client for vLLM server
+            try:
+                req: Dict[str, Any] = {
+                    "input": text,
+                    "model": self.id,
+                }
+                if self.request_params:
+                    req.update(self.request_params)
+                response: "CreateEmbeddingResponse" = await self._get_async_remote_client().embeddings.create(**req)
+                return response.data[0].embedding
+            except Exception as e:
+                logger.warning(f"Error in async remote embedding: {e}")
+                return []
+        else:
+            # Local mode: use thread executor for CPU-bound operations
+            loop = asyncio.get_event_loop()
+            return await loop.run_in_executor(None, self.get_embedding, text)
+
+    async def async_get_embedding_and_usage(self, text: str) -> Tuple[List[float], Optional[Dict]]:
+        """Async version of get_embedding_and_usage using thread executor for local mode."""
+        if self.is_remote:
+            try:
+                req: Dict[str, Any] = {
+                    "input": text,
+                    "model": self.id,
+                }
+                if self.request_params:
+                    req.update(self.request_params)
+                response: "CreateEmbeddingResponse" = await self._get_async_remote_client().embeddings.create(**req)
+                embedding = response.data[0].embedding
+                usage = response.usage
+                return embedding, usage.model_dump() if usage else None
+            except Exception as e:
+                logger.warning(f"Error in async remote embedding: {e}")
+                return [], None
+        else:
+            # Local mode: use thread executor for CPU-bound operations
+            try:
+                loop = asyncio.get_event_loop()
+                return await loop.run_in_executor(None, self.get_embedding_and_usage, text)
+            except Exception as e:
+                logger.warning(f"Error in async local embedding: {e}")
+                return [], None
+
+    async def async_get_embeddings_batch_and_usage(
+        self, texts: List[str]
+    ) -> Tuple[List[List[float]], List[Optional[Dict]]]:
+        """
+        Get embeddings and usage for multiple texts in batches (async version).
+
+        Args:
+            texts: List of text strings to embed
+
+        Returns:
+            Tuple of (List of embedding vectors, List of usage dictionaries)
+        """
+        all_embeddings = []
+        all_usage = []
+        logger.info(f"Getting embeddings for {len(texts)} texts in batches of {self.batch_size} (async)")
+
+        for i in range(0, len(texts), self.batch_size):
+            batch_texts = texts[i : i + self.batch_size]
+
+            try:
+                if self.is_remote:
+                    # Remote mode: use batch API
+                    req: Dict[str, Any] = {
+                        "input": batch_texts,
+                        "model": self.id,
+                    }
+                    if self.request_params:
+                        req.update(self.request_params)
+                    response: "CreateEmbeddingResponse" = await self._get_async_remote_client().embeddings.create(**req)
+                    batch_embeddings = [data.embedding for data in response.data]
+                    all_embeddings.extend(batch_embeddings)
+
+                    # For each embedding in the batch, add the same usage information
+                    usage_dict = response.usage.model_dump() if response.usage else None
+                    all_usage.extend([usage_dict] * len(batch_embeddings))
+                else:
+                    # Local mode: process individually using thread executor
+                    for text in batch_texts:
+                        embedding, usage = await self.async_get_embedding_and_usage(text)
+                        all_embeddings.append(embedding)
+                        all_usage.append(usage)
+
+            except Exception as e:
+                logger.warning(f"Error in async batch embedding: {e}")
+                # Fallback: add empty results for failed batch
+                for _ in batch_texts:
+                    all_embeddings.append([])
+                    all_usage.append(None)
+
+        return all_embeddings, all_usage
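The class docstring above already sketches both construction styles. A minimal usage sketch based on that docstring follows; the import path is an assumption, since this hunk does not show where the new file lives in the package, and it presumes either a local `vllm` install or a running vLLM server reachable through the `openai` client.

# Hedged usage sketch; the module path below is assumed, not shown in the diff.
from agno.knowledge.embedder.vllm import VLLMEmbedder  # assumed location of the new file

# Local mode (default): loads the model in-process on your GPU/CPU, no API key needed.
local_embedder = VLLMEmbedder(id="intfloat/e5-mistral-7b-instruct")
vector = local_embedder.get_embedding("The quick brown fox")

# Remote mode: connects to a vLLM server's OpenAI-compatible endpoint.
remote_embedder = VLLMEmbedder(
    id="intfloat/e5-mistral-7b-instruct",
    base_url="http://localhost:8000/v1",
    api_key="your-key",  # optional; vLLM servers can run without auth
)
vector, usage = remote_embedder.get_embedding_and_usage("The quick brown fox")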
@@ -4,7 +4,6 @@ import io
 import time
 from dataclasses import dataclass
 from enum import Enum
-from functools import cached_property
 from io import BytesIO
 from os.path import basename
 from pathlib import Path
@@ -187,10 +186,14 @@ class Knowledge:
         paths: Optional[List[str]] = None,
         urls: Optional[List[str]] = None,
         metadata: Optional[Dict[str, str]] = None,
+        topics: Optional[List[str]] = None,
+        text_contents: Optional[List[str]] = None,
+        reader: Optional[Reader] = None,
         include: Optional[List[str]] = None,
         exclude: Optional[List[str]] = None,
         upsert: bool = True,
         skip_if_exists: bool = False,
+        remote_content: Optional[RemoteContent] = None,
     ) -> None: ...
 
     def add_contents(self, *args, **kwargs) -> None:
@@ -208,10 +211,14 @@ class Knowledge:
             paths: Optional list of file paths to load content from
             urls: Optional list of URLs to load content from
             metadata: Optional metadata dictionary to apply to all content
+            topics: Optional list of topics to add
+            text_contents: Optional list of text content strings to add
+            reader: Optional reader to use for processing content
             include: Optional list of file patterns to include
             exclude: Optional list of file patterns to exclude
             upsert: Whether to update existing content if it already exists
             skip_if_exists: Whether to skip adding content if it already exists
+            remote_content: Optional remote content (S3, GCS, etc.) to add
         """
         asyncio.run(self.add_contents_async(*args, **kwargs))
 
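The docstring hunk above documents the new add_contents parameters (topics, text_contents, reader, remote_content). A hypothetical call using them, assuming an already-configured Knowledge instance named knowledge; constructor arguments and reader/remote_content values are omitted because the diff does not show them.

# Hypothetical sketch of the extended add_contents signature; `knowledge` is an
# assumed, pre-configured Knowledge instance.
knowledge.add_contents(
    urls=["https://example.com/guide.pdf"],
    text_contents=["Agno agents can search this knowledge base."],
    topics=["agents"],
    metadata={"source": "docs"},
    upsert=True,
    skip_if_exists=False,
)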
@@ -1449,14 +1456,16 @@ class Knowledge:
     def get_valid_filters(self) -> Set[str]:
         if self.valid_metadata_filters is None:
             self.valid_metadata_filters = set()
-            self.valid_metadata_filters.update(self._get_filters_from_db)
+            self.valid_metadata_filters.update(self._get_filters_from_db())
         return self.valid_metadata_filters
 
-    def validate_filters(self, filters: Optional[Dict[str, Any]]) -> Tuple[Dict[str, Any], List[str]]:
+    async def aget_valid_filters(self) -> Set[str]:
         if self.valid_metadata_filters is None:
             self.valid_metadata_filters = set()
-            self.valid_metadata_filters.update(self._get_filters_from_db)
+            self.valid_metadata_filters.update(await self._aget_filters_from_db())
+        return self.valid_metadata_filters
 
+    def _validate_filters(self, filters: Optional[Dict[str, Any]]) -> Tuple[Dict[str, Any], List[str]]:
         if not filters:
             return {}, []
 
@@ -1480,6 +1489,20 @@
 
         return valid_filters, invalid_keys
 
+    def validate_filters(self, filters: Optional[Dict[str, Any]]) -> Tuple[Dict[str, Any], List[str]]:
+        if self.valid_metadata_filters is None:
+            self.valid_metadata_filters = set()
+            self.valid_metadata_filters.update(self._get_filters_from_db())
+
+        return self._validate_filters(filters)
+
+    async def async_validate_filters(self, filters: Optional[Dict[str, Any]]) -> Tuple[Dict[str, Any], List[str]]:
+        if self.valid_metadata_filters is None:
+            self.valid_metadata_filters = set()
+            self.valid_metadata_filters.update(await self._aget_filters_from_db())
+
+        return self._validate_filters(filters)
+
     def add_filters(self, metadata: Dict[str, Any]) -> None:
         if self.valid_metadata_filters is None:
             self.valid_metadata_filters = set()
@@ -1488,7 +1511,6 @@
         for key in metadata.keys():
             self.valid_metadata_filters.add(key)
 
-    @cached_property
     def _get_filters_from_db(self) -> Set[str]:
         if self.contents_db is None:
             return set()
@@ -1499,6 +1521,16 @@
                 valid_filters.update(content.metadata.keys())
         return valid_filters
 
+    async def _aget_filters_from_db(self) -> Set[str]:
+        if self.contents_db is None:
+            return set()
+        contents, _ = await self.aget_content()
+        valid_filters: Set[str] = set()
+        for content in contents:
+            if content.metadata:
+                valid_filters.update(content.metadata.keys())
+        return valid_filters
+
     def remove_vector_by_id(self, id: str) -> bool:
         from agno.vectordb import VectorDb
 
agno/models/base.py CHANGED
@@ -31,7 +31,7 @@ from agno.models.metrics import Metrics
 from agno.models.response import ModelResponse, ModelResponseEvent, ToolExecution
 from agno.run.agent import CustomEvent, RunContentEvent, RunOutput, RunOutputEvent
 from agno.run.team import RunContentEvent as TeamRunContentEvent
-from agno.run.team import TeamRunOutputEvent
+from agno.run.team import TeamRunOutput, TeamRunOutputEvent
 from agno.run.workflow import WorkflowRunOutputEvent
 from agno.tools.function import Function, FunctionCall, FunctionExecutionResult, UserInputField
 from agno.utils.log import log_debug, log_error, log_info, log_warning
@@ -53,6 +53,8 @@ class MessageData:
     response_video: Optional[Video] = None
     response_file: Optional[File] = None
 
+    response_metrics: Optional[Metrics] = None
+
     # Data from the provider that we might need on subsequent messages
     response_provider_data: Optional[Dict[str, Any]] = None
 
@@ -308,7 +310,7 @@ class Model(ABC):
         tools: Optional[List[Union[Function, dict]]] = None,
         tool_choice: Optional[Union[str, Dict[str, Any]]] = None,
         tool_call_limit: Optional[int] = None,
-        run_response: Optional[RunOutput] = None,
+        run_response: Optional[Union[RunOutput, TeamRunOutput]] = None,
         send_media_to_model: bool = True,
     ) -> ModelResponse:
         """
@@ -482,6 +484,7 @@ class Model(ABC):
         tools: Optional[List[Union[Function, dict]]] = None,
         tool_choice: Optional[Union[str, Dict[str, Any]]] = None,
         tool_call_limit: Optional[int] = None,
+        run_response: Optional[Union[RunOutput, TeamRunOutput]] = None,
         send_media_to_model: bool = True,
     ) -> ModelResponse:
         """
@@ -517,6 +520,7 @@ class Model(ABC):
             response_format=response_format,
             tools=_tool_dicts,
             tool_choice=tool_choice or self._tool_choice,
+            run_response=run_response,
         )
 
         # Add assistant message to messages
@@ -644,7 +648,7 @@ class Model(ABC):
         response_format: Optional[Union[Dict, Type[BaseModel]]] = None,
         tools: Optional[List[Dict[str, Any]]] = None,
         tool_choice: Optional[Union[str, Dict[str, Any]]] = None,
-        run_response: Optional[RunOutput] = None,
+        run_response: Optional[Union[RunOutput, TeamRunOutput]] = None,
     ) -> None:
         """
         Process a single model response and return the assistant message and whether to continue.
@@ -697,7 +701,7 @@ class Model(ABC):
         response_format: Optional[Union[Dict, Type[BaseModel]]] = None,
         tools: Optional[List[Dict[str, Any]]] = None,
         tool_choice: Optional[Union[str, Dict[str, Any]]] = None,
-        run_response: Optional[RunOutput] = None,
+        run_response: Optional[Union[RunOutput, TeamRunOutput]] = None,
     ) -> None:
         """
         Process a single async model response and return the assistant message and whether to continue.
@@ -757,7 +761,6 @@ class Model(ABC):
         Returns:
             Message: The populated assistant message
         """
-        # Add role to assistant message
         if provider_response.role is not None:
             assistant_message.role = provider_response.role
 
@@ -821,7 +824,7 @@ class Model(ABC):
         response_format: Optional[Union[Dict, Type[BaseModel]]] = None,
         tools: Optional[List[Dict[str, Any]]] = None,
         tool_choice: Optional[Union[str, Dict[str, Any]]] = None,
-        run_response: Optional[RunOutput] = None,
+        run_response: Optional[Union[RunOutput, TeamRunOutput]] = None,
     ) -> Iterator[ModelResponse]:
         """
         Process a streaming response from the model.
@@ -835,14 +838,14 @@ class Model(ABC):
             tool_choice=tool_choice or self._tool_choice,
            run_response=run_response,
         ):
-            yield from self._populate_stream_data_and_assistant_message(
+            for model_response_delta in self._populate_stream_data(
                 stream_data=stream_data,
-                assistant_message=assistant_message,
                 model_response_delta=response_delta,
-            )
+            ):
+                yield model_response_delta
 
-        # Add final metrics to assistant message
-        self._populate_assistant_message(assistant_message=assistant_message, provider_response=response_delta)
+        # Populate assistant message from stream data after the stream ends
+        self._populate_assistant_message_from_stream_data(assistant_message=assistant_message, stream_data=stream_data)
 
     def response_stream(
         self,
@@ -852,7 +855,7 @@ class Model(ABC):
         tool_choice: Optional[Union[str, Dict[str, Any]]] = None,
         tool_call_limit: Optional[int] = None,
         stream_model_response: bool = True,
-        run_response: Optional[RunOutput] = None,
+        run_response: Optional[Union[RunOutput, TeamRunOutput]] = None,
         send_media_to_model: bool = True,
     ) -> Iterator[Union[ModelResponse, RunOutputEvent, TeamRunOutputEvent]]:
         """
@@ -906,22 +909,6 @@ class Model(ABC):
                     streaming_responses.append(response)
                 yield response
 
-            # Populate assistant message from stream data
-            if stream_data.response_content:
-                assistant_message.content = stream_data.response_content
-            if stream_data.response_reasoning_content:
-                assistant_message.reasoning_content = stream_data.response_reasoning_content
-            if stream_data.response_redacted_reasoning_content:
-                assistant_message.redacted_reasoning_content = stream_data.response_redacted_reasoning_content
-            if stream_data.response_provider_data:
-                assistant_message.provider_data = stream_data.response_provider_data
-            if stream_data.response_citations:
-                assistant_message.citations = stream_data.response_citations
-            if stream_data.response_audio:
-                assistant_message.audio_output = stream_data.response_audio
-            if stream_data.response_tool_calls and len(stream_data.response_tool_calls) > 0:
-                assistant_message.tool_calls = self.parse_tool_calls(stream_data.response_tool_calls)
-
         else:
             self._process_model_response(
                 messages=messages,
@@ -1020,7 +1007,7 @@ class Model(ABC):
         response_format: Optional[Union[Dict, Type[BaseModel]]] = None,
         tools: Optional[List[Dict[str, Any]]] = None,
         tool_choice: Optional[Union[str, Dict[str, Any]]] = None,
-        run_response: Optional[RunOutput] = None,
+        run_response: Optional[Union[RunOutput, TeamRunOutput]] = None,
    ) -> AsyncIterator[ModelResponse]:
         """
         Process a streaming response from the model.
@@ -1033,15 +1020,14 @@ class Model(ABC):
             tool_choice=tool_choice or self._tool_choice,
             run_response=run_response,
         ):  # type: ignore
-            for model_response in self._populate_stream_data_and_assistant_message(
+            for model_response_delta in self._populate_stream_data(
                 stream_data=stream_data,
-                assistant_message=assistant_message,
                 model_response_delta=response_delta,
             ):
-                yield model_response
+                yield model_response_delta
 
-        # Populate the assistant message
-        self._populate_assistant_message(assistant_message=assistant_message, provider_response=model_response)
+        # Populate assistant message from stream data after the stream ends
+        self._populate_assistant_message_from_stream_data(assistant_message=assistant_message, stream_data=stream_data)
 
     async def aresponse_stream(
         self,
@@ -1051,7 +1037,7 @@ class Model(ABC):
         tool_choice: Optional[Union[str, Dict[str, Any]]] = None,
         tool_call_limit: Optional[int] = None,
         stream_model_response: bool = True,
-        run_response: Optional[RunOutput] = None,
+        run_response: Optional[Union[RunOutput, TeamRunOutput]] = None,
         send_media_to_model: bool = True,
     ) -> AsyncIterator[Union[ModelResponse, RunOutputEvent, TeamRunOutputEvent]]:
         """
@@ -1105,20 +1091,6 @@ class Model(ABC):
                     streaming_responses.append(model_response)
                 yield model_response
 
-            # Populate assistant message from stream data
-            if stream_data.response_content:
-                assistant_message.content = stream_data.response_content
-            if stream_data.response_reasoning_content:
-                assistant_message.reasoning_content = stream_data.response_reasoning_content
-            if stream_data.response_redacted_reasoning_content:
-                assistant_message.redacted_reasoning_content = stream_data.response_redacted_reasoning_content
-            if stream_data.response_provider_data:
-                assistant_message.provider_data = stream_data.response_provider_data
-            if stream_data.response_audio:
-                assistant_message.audio_output = stream_data.response_audio
-            if stream_data.response_tool_calls and len(stream_data.response_tool_calls) > 0:
-                assistant_message.tool_calls = self.parse_tool_calls(stream_data.response_tool_calls)
-
         else:
             await self._aprocess_model_response(
                 messages=messages,
@@ -1210,15 +1182,51 @@ class Model(ABC):
         if self.cache_response and cache_key and streaming_responses:
             self._save_streaming_responses_to_cache(cache_key, streaming_responses)
 
-    def _populate_stream_data_and_assistant_message(
-        self, stream_data: MessageData, assistant_message: Message, model_response_delta: ModelResponse
+    def _populate_assistant_message_from_stream_data(
+        self, assistant_message: Message, stream_data: MessageData
+    ) -> None:
+        """
+        Populate an assistant message with the stream data.
+        """
+        if stream_data.response_role is not None:
+            assistant_message.role = stream_data.response_role
+        if stream_data.response_metrics is not None:
+            assistant_message.metrics = stream_data.response_metrics
+        if stream_data.response_content:
+            assistant_message.content = stream_data.response_content
+        if stream_data.response_reasoning_content:
+            assistant_message.reasoning_content = stream_data.response_reasoning_content
+        if stream_data.response_redacted_reasoning_content:
+            assistant_message.redacted_reasoning_content = stream_data.response_redacted_reasoning_content
+        if stream_data.response_provider_data:
+            assistant_message.provider_data = stream_data.response_provider_data
+        if stream_data.response_citations:
+            assistant_message.citations = stream_data.response_citations
+        if stream_data.response_audio:
+            assistant_message.audio_output = stream_data.response_audio
+        if stream_data.response_image:
+            assistant_message.image_output = stream_data.response_image
+        if stream_data.response_video:
+            assistant_message.video_output = stream_data.response_video
+        if stream_data.response_file:
+            assistant_message.file_output = stream_data.response_file
+        if stream_data.response_tool_calls and len(stream_data.response_tool_calls) > 0:
+            assistant_message.tool_calls = self.parse_tool_calls(stream_data.response_tool_calls)
+
+    def _populate_stream_data(
+        self, stream_data: MessageData, model_response_delta: ModelResponse
     ) -> Iterator[ModelResponse]:
         """Update the stream data and assistant message with the model response."""
-        # Add role to assistant message
-        if model_response_delta.role is not None:
-            assistant_message.role = model_response_delta.role
 
         should_yield = False
+        if model_response_delta.role is not None:
+            stream_data.response_role = model_response_delta.role  # type: ignore
+
+        if model_response_delta.response_usage is not None:
+            if stream_data.response_metrics is None:
+                stream_data.response_metrics = Metrics()
+            stream_data.response_metrics += model_response_delta.response_usage
+
         # Update stream_data content
         if model_response_delta.content is not None:
             stream_data.response_content += model_response_delta.content