agno 2.3.21__py3-none-any.whl → 2.3.23__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (74):
  1. agno/agent/agent.py +48 -2
  2. agno/agent/remote.py +234 -73
  3. agno/client/a2a/__init__.py +10 -0
  4. agno/client/a2a/client.py +554 -0
  5. agno/client/a2a/schemas.py +112 -0
  6. agno/client/a2a/utils.py +369 -0
  7. agno/db/migrations/utils.py +19 -0
  8. agno/db/migrations/v1_to_v2.py +54 -16
  9. agno/db/migrations/versions/v2_3_0.py +92 -53
  10. agno/db/mysql/async_mysql.py +5 -7
  11. agno/db/mysql/mysql.py +5 -7
  12. agno/db/mysql/schemas.py +39 -21
  13. agno/db/postgres/async_postgres.py +172 -42
  14. agno/db/postgres/postgres.py +186 -38
  15. agno/db/postgres/schemas.py +39 -21
  16. agno/db/postgres/utils.py +6 -2
  17. agno/db/singlestore/schemas.py +41 -21
  18. agno/db/singlestore/singlestore.py +14 -3
  19. agno/db/sqlite/async_sqlite.py +7 -2
  20. agno/db/sqlite/schemas.py +36 -21
  21. agno/db/sqlite/sqlite.py +3 -7
  22. agno/knowledge/chunking/document.py +3 -2
  23. agno/knowledge/chunking/markdown.py +8 -3
  24. agno/knowledge/chunking/recursive.py +2 -2
  25. agno/models/base.py +4 -0
  26. agno/models/google/gemini.py +27 -4
  27. agno/models/openai/chat.py +1 -1
  28. agno/models/openai/responses.py +14 -7
  29. agno/os/middleware/jwt.py +66 -27
  30. agno/os/routers/agents/router.py +3 -3
  31. agno/os/routers/evals/evals.py +2 -2
  32. agno/os/routers/knowledge/knowledge.py +5 -5
  33. agno/os/routers/knowledge/schemas.py +1 -1
  34. agno/os/routers/memory/memory.py +4 -4
  35. agno/os/routers/session/session.py +2 -2
  36. agno/os/routers/teams/router.py +4 -4
  37. agno/os/routers/traces/traces.py +3 -3
  38. agno/os/routers/workflows/router.py +3 -3
  39. agno/os/schema.py +1 -1
  40. agno/reasoning/deepseek.py +11 -1
  41. agno/reasoning/gemini.py +6 -2
  42. agno/reasoning/groq.py +8 -3
  43. agno/reasoning/openai.py +2 -0
  44. agno/remote/base.py +106 -9
  45. agno/skills/__init__.py +17 -0
  46. agno/skills/agent_skills.py +370 -0
  47. agno/skills/errors.py +32 -0
  48. agno/skills/loaders/__init__.py +4 -0
  49. agno/skills/loaders/base.py +27 -0
  50. agno/skills/loaders/local.py +216 -0
  51. agno/skills/skill.py +65 -0
  52. agno/skills/utils.py +107 -0
  53. agno/skills/validator.py +277 -0
  54. agno/team/remote.py +220 -60
  55. agno/team/team.py +41 -3
  56. agno/tools/brandfetch.py +27 -18
  57. agno/tools/browserbase.py +150 -13
  58. agno/tools/function.py +6 -1
  59. agno/tools/mcp/mcp.py +300 -17
  60. agno/tools/mcp/multi_mcp.py +269 -14
  61. agno/tools/toolkit.py +89 -21
  62. agno/utils/mcp.py +49 -8
  63. agno/utils/string.py +43 -1
  64. agno/workflow/condition.py +4 -2
  65. agno/workflow/loop.py +20 -1
  66. agno/workflow/remote.py +173 -33
  67. agno/workflow/router.py +4 -1
  68. agno/workflow/steps.py +4 -0
  69. agno/workflow/workflow.py +14 -0
  70. {agno-2.3.21.dist-info → agno-2.3.23.dist-info}/METADATA +13 -14
  71. {agno-2.3.21.dist-info → agno-2.3.23.dist-info}/RECORD +74 -60
  72. {agno-2.3.21.dist-info → agno-2.3.23.dist-info}/WHEEL +0 -0
  73. {agno-2.3.21.dist-info → agno-2.3.23.dist-info}/licenses/LICENSE +0 -0
  74. {agno-2.3.21.dist-info → agno-2.3.23.dist-info}/top_level.txt +0 -0
@@ -31,7 +31,7 @@ from agno.utils.log import log_debug, log_error, log_info, log_warning
31
31
  from agno.utils.string import generate_id
32
32
 
33
33
  try:
34
- from sqlalchemy import Index, UniqueConstraint, and_, func, select, update
34
+ from sqlalchemy import ForeignKey, Index, UniqueConstraint, and_, func, select, update
35
35
  from sqlalchemy.dialects import mysql
36
36
  from sqlalchemy.engine import Engine, create_engine
37
37
  from sqlalchemy.orm import scoped_session, sessionmaker
@@ -151,7 +151,10 @@ class SingleStoreDb(BaseDb):
151
151
  Table: SQLAlchemy Table object with column definitions
152
152
  """
153
153
  try:
154
- table_schema = get_table_schema_definition(table_type)
154
+ # Pass traces_table_name and db_schema for spans table foreign key resolution
155
+ table_schema = get_table_schema_definition(
156
+ table_type, traces_table_name=self.trace_table_name, db_schema=self.db_schema or "agno"
157
+ )
155
158
 
156
159
  columns: List[Column] = []
157
160
  # Get the columns from the table schema
@@ -207,7 +210,10 @@ class SingleStoreDb(BaseDb):
207
210
  """
208
211
  table_ref = f"{self.db_schema}.{table_name}" if self.db_schema else table_name
209
212
  try:
210
- table_schema = get_table_schema_definition(table_type)
213
+ # Pass traces_table_name and db_schema for spans table foreign key resolution
214
+ table_schema = get_table_schema_definition(
215
+ table_type, traces_table_name=self.trace_table_name, db_schema=self.db_schema or "agno"
216
+ ).copy()
211
217
 
212
218
  columns: List[Column] = []
213
219
  indexes: List[str] = []
@@ -227,6 +233,11 @@ class SingleStoreDb(BaseDb):
227
233
  if col_config.get("unique", False):
228
234
  column_kwargs["unique"] = True
229
235
  unique_constraints.append(col_name)
236
+
237
+ # Handle foreign key constraint
238
+ if "foreign_key" in col_config:
239
+ column_args.append(ForeignKey(col_config["foreign_key"]))
240
+
230
241
  columns.append(Column(*column_args, **column_kwargs))
231
242
 
232
243
  # Create the table object
@@ -31,7 +31,7 @@ from agno.utils.log import log_debug, log_error, log_info, log_warning
31
31
  from agno.utils.string import generate_id
32
32
 
33
33
  try:
34
- from sqlalchemy import Column, MetaData, String, Table, func, select, text
34
+ from sqlalchemy import Column, ForeignKey, MetaData, String, Table, func, select, text
35
35
  from sqlalchemy.dialects import sqlite
36
36
  from sqlalchemy.ext.asyncio import AsyncEngine, async_sessionmaker, create_async_engine
37
37
  from sqlalchemy.schema import Index, UniqueConstraint
@@ -174,7 +174,8 @@ class AsyncSqliteDb(AsyncBaseDb):
174
174
  Table: SQLAlchemy Table object
175
175
  """
176
176
  try:
177
- table_schema = get_table_schema_definition(table_type)
177
+ # Pass traces_table_name for spans table foreign key resolution
178
+ table_schema = get_table_schema_definition(table_type, traces_table_name=self.trace_table_name).copy()
178
179
 
179
180
  columns: List[Column] = []
180
181
  indexes: List[str] = []
@@ -196,6 +197,10 @@ class AsyncSqliteDb(AsyncBaseDb):
196
197
  column_kwargs["unique"] = True
197
198
  unique_constraints.append(col_name)
198
199
 
200
+ # Handle foreign key constraint
201
+ if "foreign_key" in col_config:
202
+ column_args.append(ForeignKey(col_config["foreign_key"]))
203
+
199
204
  columns.append(Column(*column_args, **column_kwargs)) # type: ignore
200
205
 
201
206
  # Create the table object
agno/db/sqlite/schemas.py CHANGED
@@ -111,25 +111,36 @@ TRACE_TABLE_SCHEMA = {
111
111
  "created_at": {"type": String, "nullable": False, "index": True}, # ISO 8601 datetime string
112
112
  }
113
113
 
114
- SPAN_TABLE_SCHEMA = {
115
- "span_id": {"type": String, "primary_key": True, "nullable": False},
116
- "trace_id": {
117
- "type": String,
118
- "nullable": False,
119
- "index": True,
120
- "foreign_key": "agno_traces.trace_id", # Foreign key to traces table
121
- },
122
- "parent_span_id": {"type": String, "nullable": True, "index": True},
123
- "name": {"type": String, "nullable": False},
124
- "span_kind": {"type": String, "nullable": False},
125
- "status_code": {"type": String, "nullable": False},
126
- "status_message": {"type": String, "nullable": True},
127
- "start_time": {"type": String, "nullable": False, "index": True}, # ISO 8601 datetime string
128
- "end_time": {"type": String, "nullable": False}, # ISO 8601 datetime string
129
- "duration_ms": {"type": BigInteger, "nullable": False},
130
- "attributes": {"type": JSON, "nullable": True},
131
- "created_at": {"type": String, "nullable": False, "index": True}, # ISO 8601 datetime string
132
- }
114
+
115
+ def _get_span_table_schema(traces_table_name: str = "agno_traces") -> dict[str, Any]:
116
+ """Get the span table schema with the correct foreign key reference.
117
+
118
+ Args:
119
+ traces_table_name: The name of the traces table to reference in the foreign key.
120
+
121
+ Returns:
122
+ The span table schema dictionary.
123
+ """
124
+ return {
125
+ "span_id": {"type": String, "primary_key": True, "nullable": False},
126
+ "trace_id": {
127
+ "type": String,
128
+ "nullable": False,
129
+ "index": True,
130
+ "foreign_key": f"{traces_table_name}.trace_id",
131
+ },
132
+ "parent_span_id": {"type": String, "nullable": True, "index": True},
133
+ "name": {"type": String, "nullable": False},
134
+ "span_kind": {"type": String, "nullable": False},
135
+ "status_code": {"type": String, "nullable": False},
136
+ "status_message": {"type": String, "nullable": True},
137
+ "start_time": {"type": String, "nullable": False, "index": True}, # ISO 8601 datetime string
138
+ "end_time": {"type": String, "nullable": False}, # ISO 8601 datetime string
139
+ "duration_ms": {"type": BigInteger, "nullable": False},
140
+ "attributes": {"type": JSON, "nullable": True},
141
+ "created_at": {"type": String, "nullable": False, "index": True}, # ISO 8601 datetime string
142
+ }
143
+
133
144
 
134
145
  CULTURAL_KNOWLEDGE_TABLE_SCHEMA = {
135
146
  "id": {"type": String, "primary_key": True, "nullable": False},
@@ -152,16 +163,21 @@ VERSIONS_TABLE_SCHEMA = {
152
163
  }
153
164
 
154
165
 
155
- def get_table_schema_definition(table_type: str) -> dict[str, Any]:
166
+ def get_table_schema_definition(table_type: str, traces_table_name: str = "agno_traces") -> dict[str, Any]:
156
167
  """
157
168
  Get the expected schema definition for the given table.
158
169
 
159
170
  Args:
160
171
  table_type (str): The type of table to get the schema for.
172
+ traces_table_name (str): The name of the traces table (used for spans foreign key).
161
173
 
162
174
  Returns:
163
175
  Dict[str, Any]: Dictionary containing column definitions for the table
164
176
  """
177
+ # Handle spans table specially to resolve the foreign key reference
178
+ if table_type == "spans":
179
+ return _get_span_table_schema(traces_table_name)
180
+
165
181
  schemas = {
166
182
  "sessions": SESSION_TABLE_SCHEMA,
167
183
  "evals": EVAL_TABLE_SCHEMA,
@@ -169,7 +185,6 @@ def get_table_schema_definition(table_type: str) -> dict[str, Any]:
169
185
  "memories": USER_MEMORY_TABLE_SCHEMA,
170
186
  "knowledge": KNOWLEDGE_TABLE_SCHEMA,
171
187
  "traces": TRACE_TABLE_SCHEMA,
172
- "spans": SPAN_TABLE_SCHEMA,
173
188
  "culture": CULTURAL_KNOWLEDGE_TABLE_SCHEMA,
174
189
  "versions": VERSIONS_TABLE_SCHEMA,
175
190
  }
agno/db/sqlite/sqlite.py CHANGED
@@ -173,7 +173,8 @@ class SqliteDb(BaseDb):
173
173
  Table: SQLAlchemy Table object
174
174
  """
175
175
  try:
176
- table_schema = get_table_schema_definition(table_type).copy()
176
+ # Pass traces_table_name for spans table foreign key resolution
177
+ table_schema = get_table_schema_definition(table_type, traces_table_name=self.trace_table_name).copy()
177
178
 
178
179
  columns: List[Column] = []
179
180
  indexes: List[str] = []
@@ -197,12 +198,7 @@ class SqliteDb(BaseDb):
197
198
 
198
199
  # Handle foreign key constraint
199
200
  if "foreign_key" in col_config:
200
- fk_ref = col_config["foreign_key"]
201
- # For spans table, dynamically replace the traces table reference
202
- # with the actual trace table name configured for this db instance
203
- if table_type == "spans" and "trace_id" in fk_ref:
204
- fk_ref = f"{self.trace_table_name}.trace_id"
205
- column_args.append(ForeignKey(fk_ref))
201
+ column_args.append(ForeignKey(col_config["foreign_key"]))
206
202
 
207
203
  columns.append(Column(*column_args, **column_kwargs)) # type: ignore
208
204
 
@@ -16,8 +16,9 @@ class DocumentChunking(ChunkingStrategy):
16
16
  if len(document.content) <= self.chunk_size:
17
17
  return [document]
18
18
 
19
- # Split on double newlines first (paragraphs)
20
- paragraphs = self.clean_text(document.content).split("\n\n")
19
+ # Split on double newlines first (paragraphs), then clean each paragraph
20
+ raw_paragraphs = document.content.split("\n\n")
21
+ paragraphs = [self.clean_text(para) for para in raw_paragraphs]
21
22
  chunks: List[Document] = []
22
23
  current_chunk = []
23
24
  current_size = 0
@@ -35,7 +35,8 @@ class MarkdownChunking(ChunkingStrategy):
35
35
  elements = partition_md(filename=temp_file_path)
36
36
 
37
37
  if not elements:
38
- return self.clean_text(content).split("\n\n")
38
+ raw_paragraphs = content.split("\n\n")
39
+ return [self.clean_text(para) for para in raw_paragraphs]
39
40
 
40
41
  # Chunk by title with some default values
41
42
  chunked_elements = chunk_by_title(
@@ -57,7 +58,10 @@ class MarkdownChunking(ChunkingStrategy):
57
58
  if chunk_text.strip():
58
59
  text_chunks.append(chunk_text.strip())
59
60
 
60
- return text_chunks if text_chunks else self.clean_text(content).split("\n\n")
61
+ if text_chunks:
62
+ return text_chunks
63
+ raw_paragraphs = content.split("\n\n")
64
+ return [self.clean_text(para) for para in raw_paragraphs]
61
65
 
62
66
  # Always clean up the temporary file
63
67
  finally:
@@ -65,7 +69,8 @@ class MarkdownChunking(ChunkingStrategy):
65
69
 
66
70
  # Fallback to simple paragraph splitting if the markdown chunking fails
67
71
  except Exception:
68
- return self.clean_text(content).split("\n\n")
72
+ raw_paragraphs = content.split("\n\n")
73
+ return [self.clean_text(para) for para in raw_paragraphs]
69
74
 
70
75
  def chunk(self, document: Document) -> List[Document]:
71
76
  """Split markdown document into chunks based on markdown structure"""
@@ -31,7 +31,7 @@ class RecursiveChunking(ChunkingStrategy):
31
31
  start = 0
32
32
  chunk_meta_data = document.meta_data
33
33
  chunk_number = 1
34
- content = self.clean_text(document.content)
34
+ content = document.content
35
35
 
36
36
  while start < len(content):
37
37
  end = min(start + self.chunk_size, len(content))
@@ -43,7 +43,7 @@ class RecursiveChunking(ChunkingStrategy):
43
43
  end = start + last_sep + 1
44
44
  break
45
45
 
46
- chunk = content[start:end]
46
+ chunk = self.clean_text(content[start:end])
47
47
  meta_data = chunk_meta_data.copy()
48
48
  meta_data["chunk"] = chunk_number
49
49
  chunk_id = None
agno/models/base.py CHANGED
@@ -1016,6 +1016,8 @@ class Model(ABC):
1016
1016
  model_response.extra.update(provider_response.extra)
1017
1017
  if provider_response.provider_data is not None:
1018
1018
  model_response.provider_data = provider_response.provider_data
1019
+ if provider_response.response_usage is not None:
1020
+ model_response.response_usage = provider_response.response_usage
1019
1021
 
1020
1022
  async def _aprocess_model_response(
1021
1023
  self,
@@ -1073,6 +1075,8 @@ class Model(ABC):
1073
1075
  model_response.extra.update(provider_response.extra)
1074
1076
  if provider_response.provider_data is not None:
1075
1077
  model_response.provider_data = provider_response.provider_data
1078
+ if provider_response.response_usage is not None:
1079
+ model_response.response_usage = provider_response.response_usage
1076
1080
 
1077
1081
  def _populate_assistant_message(
1078
1082
  self,
@@ -466,7 +466,12 @@ class Gemini(Model):
466
466
 
467
467
  except (ClientError, ServerError) as e:
468
468
  log_error(f"Error from Gemini API: {e}")
469
- error_message = str(e.response) if hasattr(e, "response") else str(e)
469
+ error_message = str(e)
470
+ if hasattr(e, "response"):
471
+ if hasattr(e.response, "text"):
472
+ error_message = e.response.text
473
+ else:
474
+ error_message = str(e.response)
470
475
  raise ModelProviderError(
471
476
  message=error_message,
472
477
  status_code=e.code if hasattr(e, "code") and e.code is not None else 502,
@@ -518,8 +523,14 @@ class Gemini(Model):
518
523
 
519
524
  except (ClientError, ServerError) as e:
520
525
  log_error(f"Error from Gemini API: {e}")
526
+ error_message = str(e)
527
+ if hasattr(e, "response"):
528
+ if hasattr(e.response, "text"):
529
+ error_message = e.response.text
530
+ else:
531
+ error_message = str(e.response)
521
532
  raise ModelProviderError(
522
- message=str(e.response) if hasattr(e, "response") else str(e),
533
+ message=error_message,
523
534
  status_code=e.code if hasattr(e, "code") and e.code is not None else 502,
524
535
  model_name=self.name,
525
536
  model_id=self.id,
@@ -574,8 +585,14 @@ class Gemini(Model):
574
585
 
575
586
  except (ClientError, ServerError) as e:
576
587
  log_error(f"Error from Gemini API: {e}")
588
+ error_message = str(e)
589
+ if hasattr(e, "response"):
590
+ if hasattr(e.response, "text"):
591
+ error_message = e.response.text
592
+ else:
593
+ error_message = str(e.response)
577
594
  raise ModelProviderError(
578
- message=str(e.response) if hasattr(e, "response") else str(e),
595
+ message=error_message,
579
596
  status_code=e.code if hasattr(e, "code") and e.code is not None else 502,
580
597
  model_name=self.name,
581
598
  model_id=self.id,
@@ -628,8 +645,14 @@ class Gemini(Model):
628
645
 
629
646
  except (ClientError, ServerError) as e:
630
647
  log_error(f"Error from Gemini API: {e}")
648
+ error_message = str(e)
649
+ if hasattr(e, "response"):
650
+ if hasattr(e.response, "text"):
651
+ error_message = e.response.text
652
+ else:
653
+ error_message = str(e.response)
631
654
  raise ModelProviderError(
632
- message=str(e.response) if hasattr(e, "response") else str(e),
655
+ message=error_message,
633
656
  status_code=e.code if hasattr(e, "code") and e.code is not None else 502,
634
657
  model_name=self.name,
635
658
  model_id=self.id,
@@ -248,7 +248,7 @@ class OpenAIChat(Model):
248
248
  # Add tools
249
249
  if tools is not None and len(tools) > 0:
250
250
  # Remove unsupported fields for OpenAILike models
251
- if self.provider in ["AIMLAPI", "Fireworks", "Nvidia"]:
251
+ if self.provider in ["AIMLAPI", "Fireworks", "Nvidia", "VLLM"]:
252
252
  for tool in tools:
253
253
  if tool.get("type") == "function":
254
254
  if tool["function"].get("requires_confirmation") is not None:
@@ -13,6 +13,7 @@ from agno.models.message import Citations, Message, UrlCitation
13
13
  from agno.models.metrics import Metrics
14
14
  from agno.models.response import ModelResponse
15
15
  from agno.run.agent import RunOutput
16
+ from agno.tools.function import Function
16
17
  from agno.utils.http import get_default_async_client, get_default_sync_client
17
18
  from agno.utils.log import log_debug, log_error, log_warning
18
19
  from agno.utils.models.openai_responses import images_to_message
@@ -364,19 +365,25 @@ class OpenAIResponses(Model):
364
365
  return vector_store.id
365
366
 
366
367
  def _format_tool_params(
367
- self, messages: List[Message], tools: Optional[List[Dict[str, Any]]] = None
368
+ self, messages: List[Message], tools: Optional[List[Union[Function, Dict[str, Any]]]] = None
368
369
  ) -> List[Dict[str, Any]]:
369
370
  """Format the tool parameters for the OpenAI Responses API."""
370
371
  formatted_tools = []
371
372
  if tools:
372
373
  for _tool in tools:
373
- if _tool.get("type") == "function":
374
+ if isinstance(_tool, Function):
375
+ _tool_dict = _tool.to_dict()
376
+ _tool_dict["type"] = "function"
377
+ for prop in _tool_dict.get("parameters", {}).get("properties", {}).values():
378
+ if isinstance(prop.get("type", ""), list):
379
+ prop["type"] = prop["type"][0]
380
+ formatted_tools.append(_tool_dict)
381
+ elif _tool.get("type") == "function":
374
382
  _tool_dict = _tool.get("function", {})
375
383
  _tool_dict["type"] = "function"
376
384
  for prop in _tool_dict.get("parameters", {}).get("properties", {}).values():
377
385
  if isinstance(prop.get("type", ""), list):
378
386
  prop["type"] = prop["type"][0]
379
-
380
387
  formatted_tools.append(_tool_dict)
381
388
  else:
382
389
  formatted_tools.append(_tool)
@@ -395,7 +402,7 @@ class OpenAIResponses(Model):
395
402
 
396
403
  # Add the file IDs to the tool parameters
397
404
  for _tool in formatted_tools:
398
- if _tool["type"] == "file_search" and vector_store_id is not None:
405
+ if _tool.get("type", "") == "file_search" and vector_store_id is not None:
399
406
  _tool["vector_store_ids"] = [vector_store_id]
400
407
 
401
408
  return formatted_tools
@@ -524,12 +531,12 @@ class OpenAIResponses(Model):
524
531
  def count_tokens(
525
532
  self,
526
533
  messages: List[Message],
527
- tools: Optional[List[Dict[str, Any]]] = None,
534
+ tools: Optional[List[Union[Function, Dict[str, Any]]]] = None,
528
535
  output_schema: Optional[Union[Dict, Type[BaseModel]]] = None,
529
536
  ) -> int:
530
537
  try:
531
538
  formatted_input = self._format_messages(messages, compress_tool_results=True)
532
- formatted_tools = self._format_tool_params(messages, tools) if tools else None
539
+ formatted_tools = self._format_tool_params(messages, tools) if tools is not None else None
533
540
 
534
541
  response = self.get_client().responses.input_tokens.count(
535
542
  model=self.id,
@@ -545,7 +552,7 @@ class OpenAIResponses(Model):
545
552
  async def acount_tokens(
546
553
  self,
547
554
  messages: List[Message],
548
- tools: Optional[List[Dict[str, Any]]] = None,
555
+ tools: Optional[List[Union[Function, Dict[str, Any]]]] = None,
549
556
  output_schema: Optional[Union[Dict, Type[BaseModel]]] = None,
550
557
  ) -> int:
551
558
  """Async version of count_tokens using the async client."""
agno/os/middleware/jwt.py CHANGED
@@ -5,7 +5,7 @@ import json
5
5
  import re
6
6
  from enum import Enum
7
7
  from os import getenv
8
- from typing import Any, Dict, List, Optional
8
+ from typing import Any, Dict, Iterable, List, Optional, Union
9
9
 
10
10
  import jwt
11
11
  from fastapi import Request, Response
@@ -168,7 +168,9 @@ class JWTValidator:
168
168
  except Exception as e:
169
169
  log_warning(f"Failed to parse JWKS key: {e}")
170
170
 
171
- def validate_token(self, token: str, expected_audience: Optional[str] = None) -> Dict[str, Any]:
171
+ def validate_token(
172
+ self, token: str, expected_audience: Optional[Union[str, Iterable[str]]] = None
173
+ ) -> Dict[str, Any]:
172
174
  """
173
175
  Validate JWT token and extract claims.
174
176
 
@@ -191,10 +193,9 @@ class JWTValidator:
191
193
  }
192
194
 
193
195
  # Configure audience verification
194
- if expected_audience:
195
- decode_kwargs["audience"] = expected_audience
196
- else:
197
- decode_options["verify_aud"] = False
196
+ # We'll decode without audience verification and if we need to verify the audience,
197
+ # we'll manually verify the audience to provide better error messages
198
+ decode_options["verify_aud"] = False
198
199
 
199
200
  # If validation is disabled, decode without signature verification
200
201
  if not self.validate:
@@ -206,6 +207,7 @@ class JWTValidator:
206
207
  decode_kwargs["options"] = decode_options
207
208
 
208
209
  last_exception: Optional[Exception] = None
210
+ payload: Optional[Dict[str, Any]] = None
209
211
 
210
212
  # Try JWKS keys first if configured
211
213
  if self.jwks_keys:
@@ -222,9 +224,7 @@ class JWTValidator:
222
224
  jwk = self.jwks_keys["_default"]
223
225
 
224
226
  if jwk:
225
- return jwt.decode(token, jwk.key, **decode_kwargs)
226
- except jwt.InvalidAudienceError:
227
- raise
227
+ payload = jwt.decode(token, jwk.key, **decode_kwargs)
228
228
  except jwt.ExpiredSignatureError:
229
229
  raise
230
230
  except jwt.InvalidTokenError as e:
@@ -233,20 +233,54 @@ class JWTValidator:
233
233
  last_exception = e
234
234
 
235
235
  # Try each static verification key until one succeeds
236
- for key in self.verification_keys:
237
- try:
238
- return jwt.decode(token, key, **decode_kwargs)
239
- except jwt.InvalidAudienceError:
240
- raise
241
- except jwt.ExpiredSignatureError:
242
- raise
243
- except jwt.InvalidTokenError as e:
244
- last_exception = e
245
- continue
236
+ if payload is None:
237
+ for key in self.verification_keys:
238
+ try:
239
+ payload = jwt.decode(token, key, **decode_kwargs)
240
+ break
241
+ except jwt.ExpiredSignatureError:
242
+ raise
243
+ except jwt.InvalidTokenError as e:
244
+ last_exception = e
245
+ continue
246
246
 
247
- if last_exception:
248
- raise last_exception
249
- raise jwt.InvalidTokenError("No verification keys configured")
247
+ if payload is None:
248
+ if last_exception:
249
+ raise last_exception
250
+ raise jwt.InvalidTokenError("No verification keys configured")
251
+
252
+ # Manually verify audience if expected_audience was provided
253
+ if expected_audience:
254
+ token_audience = payload.get(self.audience_claim)
255
+ if token_audience is None:
256
+ raise jwt.InvalidTokenError(
257
+ f'Token is missing the "{self.audience_claim}" claim. '
258
+ f"Audience verification requires this claim to be present in the token."
259
+ )
260
+
261
+ # Normalize expected_audience to a list
262
+ if isinstance(expected_audience, str):
263
+ expected_audiences = [expected_audience]
264
+ elif isinstance(expected_audience, Iterable):
265
+ expected_audiences = list(expected_audience)
266
+ else:
267
+ expected_audiences = []
268
+
269
+ # Normalize token_audience to a list
270
+ if isinstance(token_audience, str):
271
+ token_audiences = [token_audience]
272
+ elif isinstance(token_audience, list):
273
+ token_audiences = token_audience
274
+ else:
275
+ token_audiences = [token_audience] if token_audience else []
276
+
277
+ # Check if any token audience matches any expected audience
278
+ if not any(aud in expected_audiences for aud in token_audiences):
279
+ raise jwt.InvalidAudienceError(
280
+ f"Invalid audience. Expected one of: {expected_audiences}, got: {token_audiences}"
281
+ )
282
+
283
+ return payload
250
284
 
251
285
  def extract_claims(self, payload: Dict[str, Any]) -> Dict[str, Any]:
252
286
  """
@@ -364,6 +398,7 @@ class JWTMiddleware(BaseHTTPMiddleware):
364
398
  user_id_claim: str = "sub",
365
399
  session_id_claim: str = "session_id",
366
400
  audience_claim: str = "aud",
401
+ audience: Optional[Union[str, Iterable[str]]] = None,
367
402
  verify_audience: bool = False,
368
403
  dependencies_claims: Optional[List[str]] = None,
369
404
  session_state_claims: Optional[List[str]] = None,
@@ -400,7 +435,8 @@ class JWTMiddleware(BaseHTTPMiddleware):
400
435
  user_id_claim: JWT claim name for user ID (default: "sub")
401
436
  session_id_claim: JWT claim name for session ID (default: "session_id")
402
437
  audience_claim: JWT claim name for audience/OS ID (default: "aud")
403
- verify_audience: Whether to verify the audience claim matches AgentOS ID (default: False)
438
+ audience: Optional expected audience claim to validate against the token's audience claim (default: AgentOS ID)
439
+ verify_audience: Whether to verify the token's audience claim matches the expected audience claim (default: False)
404
440
  dependencies_claims: A list of claims to extract from the JWT token for dependencies
405
441
  session_state_claims: A list of claims to extract from the JWT token for session state
406
442
  scope_mappings: Optional dictionary mapping route patterns to required scopes.
@@ -453,6 +489,8 @@ class JWTMiddleware(BaseHTTPMiddleware):
453
489
  self.dependencies_claims: List[str] = dependencies_claims or []
454
490
  self.session_state_claims: List[str] = session_state_claims or []
455
491
 
492
+ self.audience = audience
493
+
456
494
  # RBAC configuration (opt-in via scope_mappings)
457
495
  self.authorization = authorization
458
496
 
@@ -648,7 +686,9 @@ class JWTMiddleware(BaseHTTPMiddleware):
648
686
 
649
687
  try:
650
688
  # Validate token and extract claims (with audience verification if configured)
651
- expected_audience = agent_os_id if self.verify_audience else None
689
+ expected_audience = None
690
+ if self.verify_audience:
691
+ expected_audience = self.audience or agent_os_id
652
692
  payload: Dict[str, Any] = self.validator.validate_token(token, expected_audience) # type: ignore
653
693
 
654
694
  # Extract standard claims and store in request.state
@@ -755,11 +795,10 @@ class JWTMiddleware(BaseHTTPMiddleware):
755
795
  request.state.authenticated = True
756
796
 
757
797
  except jwt.InvalidAudienceError:
758
- log_warning(f"Invalid audience - expected: {agent_os_id}")
798
+ log_warning(f"Invalid token audience - expected: {expected_audience}")
759
799
  return self._create_error_response(
760
- 401, "Invalid audience - token not valid for this AgentOS instance", origin, cors_allowed_origins
800
+ 401, "Invalid token audience - token not valid for this AgentOS instance", origin, cors_allowed_origins
761
801
  )
762
-
763
802
  except jwt.ExpiredSignatureError as e:
764
803
  if self.validate:
765
804
  log_warning(f"Token has expired: {str(e)}")
@@ -220,11 +220,11 @@ def get_agent_router(
220
220
  kwargs = await get_request_kwargs(request, create_agent_run)
221
221
 
222
222
  if hasattr(request.state, "user_id") and request.state.user_id is not None:
223
- if user_id:
223
+ if user_id and user_id != request.state.user_id:
224
224
  log_warning("User ID parameter passed in both request state and kwargs, using request state")
225
225
  user_id = request.state.user_id
226
226
  if hasattr(request.state, "session_id") and request.state.session_id is not None:
227
- if session_id:
227
+ if session_id and session_id != request.state.session_id:
228
228
  log_warning("Session ID parameter passed in both request state and kwargs, using request state")
229
229
  session_id = request.state.session_id
230
230
  if hasattr(request.state, "session_state") and request.state.session_state is not None:
@@ -409,7 +409,7 @@ def get_agent_router(
409
409
  if agent is None:
410
410
  raise HTTPException(status_code=404, detail="Agent not found")
411
411
 
412
- cancelled = agent.cancel_run(run_id=run_id)
412
+ cancelled = await agent.acancel_run(run_id=run_id)
413
413
  if not cancelled:
414
414
  raise HTTPException(status_code=500, detail="Failed to cancel run - run not found or already completed")
415
415
 
@@ -118,8 +118,8 @@ def attach_routes(
118
118
  model_id: Optional[str] = Query(default=None, description="Model ID"),
119
119
  filter_type: Optional[EvalFilterType] = Query(default=None, description="Filter type", alias="type"),
120
120
  eval_types: Optional[List[EvalType]] = Depends(parse_eval_types_filter),
121
- limit: Optional[int] = Query(default=20, description="Number of eval runs to return"),
122
- page: Optional[int] = Query(default=1, description="Page number"),
121
+ limit: Optional[int] = Query(default=20, description="Number of eval runs to return", ge=1),
122
+ page: Optional[int] = Query(default=1, description="Page number", ge=0),
123
123
  sort_by: Optional[str] = Query(default="created_at", description="Field to sort by"),
124
124
  sort_order: Optional[SortOrder] = Query(default="desc", description="Sort order (asc or desc)"),
125
125
  db_id: Optional[str] = Query(default=None, description="The ID of the database to use"),
@@ -297,7 +297,7 @@ def attach_routes(router: APIRouter, knowledge_instances: List[Union[Knowledge,
297
297
  else:
298
298
  raise HTTPException(status_code=400, detail=f"Invalid reader_id: {update_data.reader_id}")
299
299
 
300
- updated_content_dict = knowledge.patch_content(content)
300
+ updated_content_dict = await knowledge.apatch_content(content)
301
301
  if not updated_content_dict:
302
302
  raise HTTPException(status_code=404, detail=f"Content not found: {content_id}")
303
303
 
@@ -344,8 +344,8 @@ def attach_routes(router: APIRouter, knowledge_instances: List[Union[Knowledge,
344
344
  )
345
345
  async def get_content(
346
346
  request: Request,
347
- limit: Optional[int] = Query(default=20, description="Number of content entries to return"),
348
- page: Optional[int] = Query(default=1, description="Page number"),
347
+ limit: Optional[int] = Query(default=20, description="Number of content entries to return", ge=1),
348
+ page: Optional[int] = Query(default=1, description="Page number", ge=0),
349
349
  sort_by: Optional[str] = Query(default="created_at", description="Field to sort by"),
350
350
  sort_order: Optional[SortOrder] = Query(default="desc", description="Sort order (asc or desc)"),
351
351
  db_id: Optional[str] = Query(default=None, description="The ID of the database to use"),
@@ -1029,13 +1029,13 @@ def attach_routes(router: APIRouter, knowledge_instances: List[Union[Knowledge,
1029
1029
  search_types=search_types,
1030
1030
  )
1031
1031
  )
1032
-
1032
+ filters = await knowledge.async_get_valid_filters()
1033
1033
  return ConfigResponseSchema(
1034
1034
  readers=reader_schemas,
1035
1035
  vector_dbs=vector_dbs,
1036
1036
  readersForType=types_of_readers,
1037
1037
  chunkers=chunkers_dict,
1038
- filters=knowledge.get_valid_filters(),
1038
+ filters=filters,
1039
1039
  )
1040
1040
 
1041
1041
  return router