agno 2.4.6__py3-none-any.whl → 2.4.7__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,299 @@
1
+ from os import getenv
2
+ from typing import Any, Dict, List, Literal, Optional
3
+
4
+ from pydantic import ConfigDict, Field
5
+
6
+ from agno.knowledge.document import Document
7
+ from agno.knowledge.reranker.base import Reranker
8
+ from agno.utils.log import logger
9
+
10
+ try:
11
+ from boto3 import client as AwsClient
12
+ from boto3.session import Session
13
+ from botocore.exceptions import ClientError
14
+ except ImportError:
15
+ raise ImportError("`boto3` not installed. Please install it via `pip install boto3`.")
16
+
17
+
18
# Bedrock foundation-model identifiers understood by the reranker classes below.
AMAZON_RERANK_V1: str = "amazon.rerank-v1:0"  # Amazon Rerank 1.0
COHERE_RERANK_V3_5: str = "cohere.rerank-v3-5:0"  # Cohere Rerank 3.5

# Static type describing the supported model IDs. The IDs are spelled out literally
# because type checkers only accept literal values inside Literal[...], not the
# constants declared above.
RerankerModel = Literal["amazon.rerank-v1:0", "cohere.rerank-v3-5:0"]
24
+
25
+
26
class AwsBedrockReranker(Reranker):
    """
    AWS Bedrock reranker supporting Amazon Rerank 1.0 and Cohere Rerank 3.5 models.

    This reranker uses the unified Bedrock Rerank API (bedrock-agent-runtime)
    which provides a consistent interface for both model providers.

    To use this reranker, you need to either:
    1. Set the following environment variables:
       - AWS_ACCESS_KEY_ID
       - AWS_SECRET_ACCESS_KEY
       - AWS_SESSION_TOKEN (only when using temporary credentials)
       - AWS_REGION
    2. Or provide a boto3 Session object
    3. Or rely on boto3's default credential chain (used when no key/secret is found)

    Args:
        model (str): The model ID to use. Options:
            - 'amazon.rerank-v1:0' (Amazon Rerank 1.0)
            - 'cohere.rerank-v3-5:0' (Cohere Rerank 3.5)
            Default is 'cohere.rerank-v3-5:0'.
        top_n (Optional[int]): Number of top results to return after reranking.
            If None, returns all documents reranked. Values larger than the number
            of documents are clamped to the number of documents.
        aws_region (Optional[str]): The AWS region to use. When a ``session`` is
            given, this overrides the session's own region.
        aws_access_key_id (Optional[str]): The AWS access key ID to use.
        aws_secret_access_key (Optional[str]): The AWS secret access key to use.
        aws_session_token (Optional[str]): Session token for temporary credentials.
        session (Optional[Session]): A boto3 Session object for authentication.
        additional_model_request_fields (Optional[Dict]): Additional model-specific
            parameters to pass in the request (e.g., Cohere-specific options).

    Example:
        ```python
        from agno.knowledge.reranker.aws_bedrock import AwsBedrockReranker

        # Using Cohere Rerank 3.5 (default)
        reranker = AwsBedrockReranker(
            model="cohere.rerank-v3-5:0",
            top_n=5,
            aws_region="us-west-2",
        )

        # Using Amazon Rerank 1.0
        reranker = AwsBedrockReranker(
            model="amazon.rerank-v1:0",
            top_n=10,
            aws_region="us-west-2",
        )

        # Rerank documents
        reranked_docs = reranker.rerank(query="What is machine learning?", documents=docs)
        ```

    Note:
        - Amazon Rerank 1.0 is NOT available in us-east-1 (N. Virginia).
          Use Cohere Rerank 3.5 in that region.
        - Maximum 1000 documents per request.
        - The input Document objects are mutated: their ``reranking_score`` is set.
    """

    model_config = ConfigDict(arbitrary_types_allowed=True, populate_by_name=True)

    model: str = Field(default=COHERE_RERANK_V3_5, description="Reranker model ID")
    top_n: Optional[int] = Field(default=None, description="Number of top results to return")

    aws_region: Optional[str] = Field(default=None, description="AWS region")
    aws_access_key_id: Optional[str] = Field(default=None, description="AWS access key ID")
    aws_secret_access_key: Optional[str] = Field(default=None, description="AWS secret access key")
    aws_session_token: Optional[str] = Field(default=None, description="AWS session token")
    session: Optional[Session] = Field(default=None, description="Boto3 session", exclude=True)

    additional_model_request_fields: Optional[Dict[str, Any]] = Field(
        default=None,
        description="Additional model-specific request parameters",
    )

    # Lazily created bedrock-agent-runtime client, cached after first use.
    _client: Optional[AwsClient] = None

    @property
    def client(self) -> AwsClient:
        """
        Returns a bedrock-agent-runtime client for the Rerank API.

        Resolution order:
        1. A previously created client (cached in ``_client``).
        2. ``session.client(...)``, honouring ``aws_region`` when set.
        3. Explicit or environment key/secret (plus session token when available).
        4. boto3's default credential chain.

        Returns:
            AwsClient: An instance of the bedrock-agent-runtime client.
        """
        if self._client is not None:
            return self._client

        if self.session:
            # region_name=None keeps the session's own region; an explicit aws_region wins.
            self._client = self.session.client("bedrock-agent-runtime", region_name=self.aws_region)
            return self._client

        aws_region = self.aws_region or getenv("AWS_REGION")
        aws_access_key_id = self.aws_access_key_id or getenv("AWS_ACCESS_KEY_ID")
        aws_secret_access_key = self.aws_secret_access_key or getenv("AWS_SECRET_ACCESS_KEY")

        if not aws_access_key_id or not aws_secret_access_key:
            # Fall back to default credential chain
            self._client = AwsClient(
                service_name="bedrock-agent-runtime",
                region_name=aws_region,
            )
            return self._client

        # Passing key/secret explicitly bypasses boto3's own environment lookup, so the
        # session token must be forwarded too or temporary (STS) credentials are rejected.
        # The env token is only borrowed when the key itself came from the environment,
        # so an unrelated token is never paired with user-supplied long-term keys.
        aws_session_token = self.aws_session_token
        if aws_session_token is None and not self.aws_access_key_id:
            aws_session_token = getenv("AWS_SESSION_TOKEN")

        self._client = AwsClient(
            service_name="bedrock-agent-runtime",
            region_name=aws_region,
            aws_access_key_id=aws_access_key_id,
            aws_secret_access_key=aws_secret_access_key,
            aws_session_token=aws_session_token,
        )
        return self._client

    def _get_model_arn(self) -> str:
        """
        Constructs the full model ARN for the reranker model.

        The region is taken from the client that will send the request, so the ARN
        points at the same region the call goes to. The client's region also reflects
        a session's region or boto3's configured default. ``aws_region``/``AWS_REGION``
        (default "us-west-2") are used only if the client reports no region.

        Returns:
            str: The model ARN.
        """
        region = self.client.meta.region_name or self.aws_region or getenv("AWS_REGION", "us-west-2")
        return f"arn:aws:bedrock:{region}::foundation-model/{self.model}"

    def _build_sources(self, documents: List[Document]) -> List[Dict[str, Any]]:
        """
        Convert Document objects to Bedrock Rerank API source format.

        Args:
            documents: List of Document objects to convert.

        Returns:
            List of inline TEXT RerankSource dicts, in the same order as ``documents``
            (the API's result ``index`` refers back to this order).
        """
        return [
            {
                "type": "INLINE",
                "inlineDocumentSource": {
                    "type": "TEXT",
                    "textDocument": {
                        "text": doc.content,
                    },
                },
            }
            for doc in documents
        ]

    def _rerank(self, query: str, documents: List[Document]) -> List[Document]:
        """
        Internal method to perform reranking via Bedrock Rerank API.

        Args:
            query: The query string to rank documents against.
            documents: List of Document objects to rerank.

        Returns:
            List of Document objects sorted by relevance score (highest first).
            Each returned document has its ``reranking_score`` set in place.
        """
        if not documents:
            return []

        # Validate top_n
        top_n = self.top_n
        if top_n is not None and top_n <= 0:
            logger.warning(f"top_n should be a positive integer, got {self.top_n}, setting top_n to None")
            top_n = None
        if top_n is not None:
            # Never ask for more results than there are sources to rank.
            top_n = min(top_n, len(documents))

        # Build the request
        bedrock_config: Dict[str, Any] = {
            "modelConfiguration": {
                "modelArn": self._get_model_arn(),
            },
        }
        # Add numberOfResults if top_n is specified
        if top_n is not None:
            bedrock_config["numberOfResults"] = top_n
        # Add additional model request fields if provided
        if self.additional_model_request_fields:
            bedrock_config["modelConfiguration"]["additionalModelRequestFields"] = (
                self.additional_model_request_fields
            )

        rerank_request: Dict[str, Any] = {
            "queries": [
                {
                    "type": "TEXT",
                    "textQuery": {
                        "text": query,
                    },
                }
            ],
            "sources": self._build_sources(documents),
            "rerankingConfiguration": {
                "type": "BEDROCK_RERANKING_MODEL",
                "bedrockRerankingConfiguration": bedrock_config,
            },
        }

        # Call the Rerank API
        response = self.client.rerank(**rerank_request)

        # Process results: map each result index back to the input document.
        reranked_docs: List[Document] = []
        for result in response.get("results", []):
            index = result.get("index")
            # Ignore indices that do not address an input document (including negatives,
            # which would otherwise silently wrap around in Python).
            if index is None or not 0 <= index < len(documents):
                continue
            doc = documents[index]
            doc.reranking_score = result.get("relevanceScore")
            reranked_docs.append(doc)

        # Results from API are already sorted by relevance, but ensure sorting
        reranked_docs.sort(
            key=lambda x: x.reranking_score if x.reranking_score is not None else float("-inf"),
            reverse=True,
        )

        return reranked_docs

    def rerank(self, query: str, documents: List[Document]) -> List[Document]:
        """
        Rerank documents based on their relevance to a query.

        Errors are deliberately non-fatal: on any AWS or other exception the error
        is logged and the original ``documents`` list is returned unchanged in order.

        Args:
            query: The query string to rank documents against.
            documents: List of Document objects to rerank.

        Returns:
            List of Document objects sorted by relevance score (highest first).
            Each document will have its `reranking_score` attribute set.
        """
        try:
            return self._rerank(query=query, documents=documents)
        except ClientError as e:
            error_code = e.response.get("Error", {}).get("Code", "Unknown")
            error_message = e.response.get("Error", {}).get("Message", str(e))
            logger.error(f"AWS Bedrock Rerank API error ({error_code}): {error_message}. Returning original documents.")
            return documents
        except Exception as e:
            logger.error(f"Error reranking documents: {e}. Returning original documents.")
            return documents
+
267
+
268
class CohereBedrockReranker(AwsBedrockReranker):
    """
    AwsBedrockReranker preset whose model defaults to Cohere Rerank 3.5.

    Behaves exactly like ``AwsBedrockReranker``. The only difference is that
    ``model`` defaults to ``COHERE_RERANK_V3_5`` ("cohere.rerank-v3-5:0").
    ``top_n``, credentials, region and ``additional_model_request_fields``
    are all inherited unchanged.

    Example:
        ```python
        reranker = CohereBedrockReranker(top_n=5, aws_region="us-west-2")
        reranked_docs = reranker.rerank(query="What is AI?", documents=docs)
        ```
    """

    # Declared explicitly so this preset keeps targeting Cohere even if the
    # parent class's default model ever changes.
    model: str = Field(default=COHERE_RERANK_V3_5)
282
+
283
+
284
class AmazonReranker(AwsBedrockReranker):
    """
    AwsBedrockReranker preset whose model defaults to Amazon Rerank 1.0.

    Behaves exactly like ``AwsBedrockReranker``. The only difference is that
    ``model`` defaults to ``AMAZON_RERANK_V1`` ("amazon.rerank-v1:0").

    Note: Amazon Rerank 1.0 is NOT available in us-east-1 (N. Virginia), so pick
    another region (for example us-west-2) when using this class.

    Example:
        ```python
        reranker = AmazonReranker(top_n=5, aws_region="us-west-2")
        reranked_docs = reranker.rerank(query="What is AI?", documents=docs)
        ```
    """

    # Overrides the parent's Cohere default with the Amazon model ID.
    model: str = Field(default=AMAZON_RERANK_V1)
agno/learn/machine.py CHANGED
@@ -645,12 +645,11 @@ class LearningMachine:
645
645
  for name, store in self.stores.items():
646
646
  try:
647
647
  result = await store.arecall(**context)
648
- if result is not None:
649
- results[name] = result
650
- try:
651
- log_debug(f"Recalled from {name}: {result}")
652
- except Exception:
653
- pass
648
+ results[name] = result
649
+ try:
650
+ log_debug(f"Recalled from {name}: {result}")
651
+ except Exception:
652
+ pass
654
653
  except Exception as e:
655
654
  log_warning(f"Error recalling from {name}: {e}")
656
655
 
agno/run/workflow.py CHANGED
@@ -311,6 +311,9 @@ class ConditionExecutionCompletedEvent(BaseWorkflowRunOutputEvent):
311
311
  condition_result: Optional[bool] = None
312
312
  executed_steps: Optional[int] = None
313
313
 
314
+ # Which branch was executed: "if", "else", or None (condition false with no else_steps)
315
+ branch: Optional[str] = None
316
+
314
317
  # Results from executed steps
315
318
  step_results: List[StepOutput] = field(default_factory=list)
316
319
 
agno/tools/mcp/mcp.py CHANGED
@@ -74,6 +74,13 @@ class MCPTools(Toolkit):
74
74
  Only relevant with HTTP transports (Streamable HTTP or SSE).
75
75
  Creates a new session per agent run with dynamic headers merged into connection config.
76
76
  """
77
+ # Extract these before super().__init__() to bypass early validation
78
+ # (tools aren't available until build_tools() is called)
79
+ requires_confirmation_tools = kwargs.pop("requires_confirmation_tools", None)
80
+ external_execution_required_tools = kwargs.pop("external_execution_required_tools", None)
81
+ stop_after_tool_call_tools = kwargs.pop("stop_after_tool_call_tools", None)
82
+ show_result_tools = kwargs.pop("show_result_tools", None)
83
+
77
84
  super().__init__(name="MCPTools", **kwargs)
78
85
 
79
86
  if url is not None:
@@ -92,6 +99,10 @@ class MCPTools(Toolkit):
92
99
  # because tools are not available until `initialize()` is called.
93
100
  self.include_tools = include_tools
94
101
  self.exclude_tools = exclude_tools
102
+ self.requires_confirmation_tools = requires_confirmation_tools or []
103
+ self.external_execution_required_tools = external_execution_required_tools or []
104
+ self.stop_after_tool_call_tools = stop_after_tool_call_tools or []
105
+ self.show_result_tools = show_result_tools or []
95
106
  self.refresh_connection = refresh_connection
96
107
  self.tool_name_prefix = tool_name_prefix
97
108
 
@@ -575,13 +586,27 @@ class MCPTools(Toolkit):
575
586
  mcp_tools_instance=self,
576
587
  )
577
588
  # Create a Function for the tool
589
+ # Apply toolkit-level settings
590
+ tool_name = tool.name
591
+ stop_after = tool_name in self.stop_after_tool_call_tools
592
+ show_result = tool_name in self.show_result_tools or stop_after
593
+
578
594
  f = Function(
579
- name=tool_name_prefix + tool.name,
595
+ name=tool_name_prefix + tool_name,
580
596
  description=tool.description,
581
597
  parameters=tool.inputSchema,
582
598
  entrypoint=entrypoint,
583
599
  # Set skip_entrypoint_processing to True to avoid processing the entrypoint
584
600
  skip_entrypoint_processing=True,
601
+ # Apply toolkit-level settings for HITL and control flow
602
+ requires_confirmation=tool_name in self.requires_confirmation_tools,
603
+ external_execution=tool_name in self.external_execution_required_tools,
604
+ stop_after_tool_call=stop_after,
605
+ show_result=show_result,
606
+ # Apply toolkit-level cache settings
607
+ cache_results=self.cache_results,
608
+ cache_dir=self.cache_dir,
609
+ cache_ttl=self.cache_ttl,
585
610
  )
586
611
 
587
612
  # Register the Function with the toolkit
@@ -104,7 +104,7 @@ class LanceDb(VectorDb):
104
104
  self.async_connection: Optional[lancedb.AsyncConnection] = async_connection
105
105
  self.async_table: Optional[lancedb.db.AsyncTable] = async_table
106
106
 
107
- if table_name and table_name in self.connection.table_names():
107
+ if table_name and table_name in self.connection.list_tables().tables:
108
108
  # Open the table if it exists
109
109
  try:
110
110
  self.table = self.connection.open_table(name=table_name)
@@ -186,8 +186,8 @@ class LanceDb(VectorDb):
186
186
  self.async_connection = await lancedb.connect_async(self.uri)
187
187
  # Only try to open table if it exists and we don't have it already
188
188
  if self.async_table is None:
189
- table_names = await self.async_connection.table_names()
190
- if self.table_name in table_names:
189
+ table_list = await self.async_connection.list_tables()
190
+ if self.table_name in table_list.tables:
191
191
  try:
192
192
  self.async_table = await self.async_connection.open_table(self.table_name)
193
193
  except ValueError:
@@ -199,7 +199,7 @@ class LanceDb(VectorDb):
199
199
  """Refresh the sync connection to see changes made by async operations."""
200
200
  try:
201
201
  # Re-establish sync connection to see async changes
202
- if self.connection and self.table_name in self.connection.table_names():
202
+ if self.connection is not None and self.table_name in self.connection.list_tables().tables:
203
203
  self.table = self.connection.open_table(self.table_name)
204
204
  except Exception as e:
205
205
  log_debug(f"Could not refresh sync connection: {e}")
@@ -459,7 +459,7 @@ class LanceDb(VectorDb):
459
459
  Returns:
460
460
  List[Document]: List of matching documents
461
461
  """
462
- if self.connection:
462
+ if self.connection is not None:
463
463
  self.table = self.connection.open_table(name=self.table_name)
464
464
 
465
465
  results = None
@@ -641,8 +641,8 @@ class LanceDb(VectorDb):
641
641
  # If we have an async table that was created, the table exists
642
642
  if self.async_table is not None:
643
643
  return True
644
- if self.connection:
645
- return self.table_name in self.connection.table_names()
644
+ if self.connection is not None:
645
+ return self.table_name in self.connection.list_tables().tables
646
646
  return False
647
647
 
648
648
  async def async_exists(self) -> bool:
@@ -653,8 +653,8 @@ class LanceDb(VectorDb):
653
653
  # Check if table exists in database without trying to open it
654
654
  if self.async_connection is None:
655
655
  self.async_connection = await lancedb.connect_async(self.uri)
656
- table_names = await self.async_connection.table_names()
657
- return self.table_name in table_names
656
+ table_list = await self.async_connection.list_tables()
657
+ return self.table_name in table_list.tables
658
658
 
659
659
  async def async_get_count(self) -> int:
660
660
  """Get the number of rows in the table asynchronously."""