agno 2.0.1__py3-none-any.whl → 2.3.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- agno/agent/agent.py +6015 -2823
- agno/api/api.py +2 -0
- agno/api/os.py +1 -1
- agno/culture/__init__.py +3 -0
- agno/culture/manager.py +956 -0
- agno/db/async_postgres/__init__.py +3 -0
- agno/db/base.py +385 -6
- agno/db/dynamo/dynamo.py +388 -81
- agno/db/dynamo/schemas.py +47 -10
- agno/db/dynamo/utils.py +63 -4
- agno/db/firestore/firestore.py +435 -64
- agno/db/firestore/schemas.py +11 -0
- agno/db/firestore/utils.py +102 -4
- agno/db/gcs_json/gcs_json_db.py +384 -42
- agno/db/gcs_json/utils.py +60 -26
- agno/db/in_memory/in_memory_db.py +351 -66
- agno/db/in_memory/utils.py +60 -2
- agno/db/json/json_db.py +339 -48
- agno/db/json/utils.py +60 -26
- agno/db/migrations/manager.py +199 -0
- agno/db/migrations/v1_to_v2.py +510 -37
- agno/db/migrations/versions/__init__.py +0 -0
- agno/db/migrations/versions/v2_3_0.py +938 -0
- agno/db/mongo/__init__.py +15 -1
- agno/db/mongo/async_mongo.py +2036 -0
- agno/db/mongo/mongo.py +653 -76
- agno/db/mongo/schemas.py +13 -0
- agno/db/mongo/utils.py +80 -8
- agno/db/mysql/mysql.py +687 -25
- agno/db/mysql/schemas.py +61 -37
- agno/db/mysql/utils.py +60 -2
- agno/db/postgres/__init__.py +2 -1
- agno/db/postgres/async_postgres.py +2001 -0
- agno/db/postgres/postgres.py +676 -57
- agno/db/postgres/schemas.py +43 -18
- agno/db/postgres/utils.py +164 -2
- agno/db/redis/redis.py +344 -38
- agno/db/redis/schemas.py +18 -0
- agno/db/redis/utils.py +60 -2
- agno/db/schemas/__init__.py +2 -1
- agno/db/schemas/culture.py +120 -0
- agno/db/schemas/memory.py +13 -0
- agno/db/singlestore/schemas.py +26 -1
- agno/db/singlestore/singlestore.py +687 -53
- agno/db/singlestore/utils.py +60 -2
- agno/db/sqlite/__init__.py +2 -1
- agno/db/sqlite/async_sqlite.py +2371 -0
- agno/db/sqlite/schemas.py +24 -0
- agno/db/sqlite/sqlite.py +774 -85
- agno/db/sqlite/utils.py +168 -5
- agno/db/surrealdb/__init__.py +3 -0
- agno/db/surrealdb/metrics.py +292 -0
- agno/db/surrealdb/models.py +309 -0
- agno/db/surrealdb/queries.py +71 -0
- agno/db/surrealdb/surrealdb.py +1361 -0
- agno/db/surrealdb/utils.py +147 -0
- agno/db/utils.py +50 -22
- agno/eval/accuracy.py +50 -43
- agno/eval/performance.py +6 -3
- agno/eval/reliability.py +6 -3
- agno/eval/utils.py +33 -16
- agno/exceptions.py +68 -1
- agno/filters.py +354 -0
- agno/guardrails/__init__.py +6 -0
- agno/guardrails/base.py +19 -0
- agno/guardrails/openai.py +144 -0
- agno/guardrails/pii.py +94 -0
- agno/guardrails/prompt_injection.py +52 -0
- agno/integrations/discord/client.py +1 -0
- agno/knowledge/chunking/agentic.py +13 -10
- agno/knowledge/chunking/fixed.py +1 -1
- agno/knowledge/chunking/semantic.py +40 -8
- agno/knowledge/chunking/strategy.py +59 -15
- agno/knowledge/embedder/aws_bedrock.py +9 -4
- agno/knowledge/embedder/azure_openai.py +54 -0
- agno/knowledge/embedder/base.py +2 -0
- agno/knowledge/embedder/cohere.py +184 -5
- agno/knowledge/embedder/fastembed.py +1 -1
- agno/knowledge/embedder/google.py +79 -1
- agno/knowledge/embedder/huggingface.py +9 -4
- agno/knowledge/embedder/jina.py +63 -0
- agno/knowledge/embedder/mistral.py +78 -11
- agno/knowledge/embedder/nebius.py +1 -1
- agno/knowledge/embedder/ollama.py +13 -0
- agno/knowledge/embedder/openai.py +37 -65
- agno/knowledge/embedder/sentence_transformer.py +8 -4
- agno/knowledge/embedder/vllm.py +262 -0
- agno/knowledge/embedder/voyageai.py +69 -16
- agno/knowledge/knowledge.py +594 -186
- agno/knowledge/reader/base.py +9 -2
- agno/knowledge/reader/csv_reader.py +8 -10
- agno/knowledge/reader/docx_reader.py +5 -6
- agno/knowledge/reader/field_labeled_csv_reader.py +290 -0
- agno/knowledge/reader/json_reader.py +6 -5
- agno/knowledge/reader/markdown_reader.py +13 -13
- agno/knowledge/reader/pdf_reader.py +43 -68
- agno/knowledge/reader/pptx_reader.py +101 -0
- agno/knowledge/reader/reader_factory.py +51 -6
- agno/knowledge/reader/s3_reader.py +3 -15
- agno/knowledge/reader/tavily_reader.py +194 -0
- agno/knowledge/reader/text_reader.py +13 -13
- agno/knowledge/reader/web_search_reader.py +2 -43
- agno/knowledge/reader/website_reader.py +43 -25
- agno/knowledge/reranker/__init__.py +2 -8
- agno/knowledge/types.py +9 -0
- agno/knowledge/utils.py +20 -0
- agno/media.py +72 -0
- agno/memory/manager.py +336 -82
- agno/models/aimlapi/aimlapi.py +2 -2
- agno/models/anthropic/claude.py +183 -37
- agno/models/aws/bedrock.py +52 -112
- agno/models/aws/claude.py +33 -1
- agno/models/azure/ai_foundry.py +33 -15
- agno/models/azure/openai_chat.py +25 -8
- agno/models/base.py +999 -519
- agno/models/cerebras/cerebras.py +19 -13
- agno/models/cerebras/cerebras_openai.py +8 -5
- agno/models/cohere/chat.py +27 -1
- agno/models/cometapi/__init__.py +5 -0
- agno/models/cometapi/cometapi.py +57 -0
- agno/models/dashscope/dashscope.py +1 -0
- agno/models/deepinfra/deepinfra.py +2 -2
- agno/models/deepseek/deepseek.py +2 -2
- agno/models/fireworks/fireworks.py +2 -2
- agno/models/google/gemini.py +103 -31
- agno/models/groq/groq.py +28 -11
- agno/models/huggingface/huggingface.py +2 -1
- agno/models/internlm/internlm.py +2 -2
- agno/models/langdb/langdb.py +4 -4
- agno/models/litellm/chat.py +18 -1
- agno/models/litellm/litellm_openai.py +2 -2
- agno/models/llama_cpp/__init__.py +5 -0
- agno/models/llama_cpp/llama_cpp.py +22 -0
- agno/models/message.py +139 -0
- agno/models/meta/llama.py +27 -10
- agno/models/meta/llama_openai.py +5 -17
- agno/models/nebius/nebius.py +6 -6
- agno/models/nexus/__init__.py +3 -0
- agno/models/nexus/nexus.py +22 -0
- agno/models/nvidia/nvidia.py +2 -2
- agno/models/ollama/chat.py +59 -5
- agno/models/openai/chat.py +69 -29
- agno/models/openai/responses.py +103 -106
- agno/models/openrouter/openrouter.py +41 -3
- agno/models/perplexity/perplexity.py +4 -5
- agno/models/portkey/portkey.py +3 -3
- agno/models/requesty/__init__.py +5 -0
- agno/models/requesty/requesty.py +52 -0
- agno/models/response.py +77 -1
- agno/models/sambanova/sambanova.py +2 -2
- agno/models/siliconflow/__init__.py +5 -0
- agno/models/siliconflow/siliconflow.py +25 -0
- agno/models/together/together.py +2 -2
- agno/models/utils.py +254 -8
- agno/models/vercel/v0.py +2 -2
- agno/models/vertexai/__init__.py +0 -0
- agno/models/vertexai/claude.py +96 -0
- agno/models/vllm/vllm.py +1 -0
- agno/models/xai/xai.py +3 -2
- agno/os/app.py +543 -178
- agno/os/auth.py +24 -14
- agno/os/config.py +1 -0
- agno/os/interfaces/__init__.py +1 -0
- agno/os/interfaces/a2a/__init__.py +3 -0
- agno/os/interfaces/a2a/a2a.py +42 -0
- agno/os/interfaces/a2a/router.py +250 -0
- agno/os/interfaces/a2a/utils.py +924 -0
- agno/os/interfaces/agui/agui.py +23 -7
- agno/os/interfaces/agui/router.py +27 -3
- agno/os/interfaces/agui/utils.py +242 -142
- agno/os/interfaces/base.py +6 -2
- agno/os/interfaces/slack/router.py +81 -23
- agno/os/interfaces/slack/slack.py +29 -14
- agno/os/interfaces/whatsapp/router.py +11 -4
- agno/os/interfaces/whatsapp/whatsapp.py +14 -7
- agno/os/mcp.py +111 -54
- agno/os/middleware/__init__.py +7 -0
- agno/os/middleware/jwt.py +233 -0
- agno/os/router.py +556 -139
- agno/os/routers/evals/evals.py +71 -34
- agno/os/routers/evals/schemas.py +31 -31
- agno/os/routers/evals/utils.py +6 -5
- agno/os/routers/health.py +31 -0
- agno/os/routers/home.py +52 -0
- agno/os/routers/knowledge/knowledge.py +185 -38
- agno/os/routers/knowledge/schemas.py +82 -22
- agno/os/routers/memory/memory.py +158 -53
- agno/os/routers/memory/schemas.py +20 -16
- agno/os/routers/metrics/metrics.py +20 -8
- agno/os/routers/metrics/schemas.py +16 -16
- agno/os/routers/session/session.py +499 -38
- agno/os/schema.py +308 -198
- agno/os/utils.py +401 -41
- agno/reasoning/anthropic.py +80 -0
- agno/reasoning/azure_ai_foundry.py +2 -2
- agno/reasoning/deepseek.py +2 -2
- agno/reasoning/default.py +3 -1
- agno/reasoning/gemini.py +73 -0
- agno/reasoning/groq.py +2 -2
- agno/reasoning/ollama.py +2 -2
- agno/reasoning/openai.py +7 -2
- agno/reasoning/vertexai.py +76 -0
- agno/run/__init__.py +6 -0
- agno/run/agent.py +248 -94
- agno/run/base.py +44 -5
- agno/run/team.py +238 -97
- agno/run/workflow.py +144 -33
- agno/session/agent.py +105 -89
- agno/session/summary.py +65 -25
- agno/session/team.py +176 -96
- agno/session/workflow.py +406 -40
- agno/team/team.py +3854 -1610
- agno/tools/dalle.py +2 -4
- agno/tools/decorator.py +4 -2
- agno/tools/duckduckgo.py +15 -11
- agno/tools/e2b.py +14 -7
- agno/tools/eleven_labs.py +23 -25
- agno/tools/exa.py +21 -16
- agno/tools/file.py +153 -23
- agno/tools/file_generation.py +350 -0
- agno/tools/firecrawl.py +4 -4
- agno/tools/function.py +250 -30
- agno/tools/gmail.py +238 -14
- agno/tools/google_drive.py +270 -0
- agno/tools/googlecalendar.py +36 -8
- agno/tools/googlesheets.py +20 -5
- agno/tools/jira.py +20 -0
- agno/tools/knowledge.py +3 -3
- agno/tools/mcp/__init__.py +10 -0
- agno/tools/mcp/mcp.py +331 -0
- agno/tools/mcp/multi_mcp.py +347 -0
- agno/tools/mcp/params.py +24 -0
- agno/tools/mcp_toolbox.py +284 -0
- agno/tools/mem0.py +11 -17
- agno/tools/memori.py +1 -53
- agno/tools/memory.py +419 -0
- agno/tools/models/nebius.py +5 -5
- agno/tools/models_labs.py +20 -10
- agno/tools/notion.py +204 -0
- agno/tools/parallel.py +314 -0
- agno/tools/scrapegraph.py +58 -31
- agno/tools/searxng.py +2 -2
- agno/tools/serper.py +2 -2
- agno/tools/slack.py +18 -3
- agno/tools/spider.py +2 -2
- agno/tools/tavily.py +146 -0
- agno/tools/whatsapp.py +1 -1
- agno/tools/workflow.py +278 -0
- agno/tools/yfinance.py +12 -11
- agno/utils/agent.py +820 -0
- agno/utils/audio.py +27 -0
- agno/utils/common.py +90 -1
- agno/utils/events.py +217 -2
- agno/utils/gemini.py +180 -22
- agno/utils/hooks.py +57 -0
- agno/utils/http.py +111 -0
- agno/utils/knowledge.py +12 -5
- agno/utils/log.py +1 -0
- agno/utils/mcp.py +92 -2
- agno/utils/media.py +188 -10
- agno/utils/merge_dict.py +22 -1
- agno/utils/message.py +60 -0
- agno/utils/models/claude.py +40 -11
- agno/utils/print_response/agent.py +105 -21
- agno/utils/print_response/team.py +103 -38
- agno/utils/print_response/workflow.py +251 -34
- agno/utils/reasoning.py +22 -1
- agno/utils/serialize.py +32 -0
- agno/utils/streamlit.py +16 -10
- agno/utils/string.py +41 -0
- agno/utils/team.py +98 -9
- agno/utils/tools.py +1 -1
- agno/vectordb/base.py +23 -4
- agno/vectordb/cassandra/cassandra.py +65 -9
- agno/vectordb/chroma/chromadb.py +182 -38
- agno/vectordb/clickhouse/clickhousedb.py +64 -11
- agno/vectordb/couchbase/couchbase.py +105 -10
- agno/vectordb/lancedb/lance_db.py +124 -133
- agno/vectordb/langchaindb/langchaindb.py +25 -7
- agno/vectordb/lightrag/lightrag.py +17 -3
- agno/vectordb/llamaindex/__init__.py +3 -0
- agno/vectordb/llamaindex/llamaindexdb.py +46 -7
- agno/vectordb/milvus/milvus.py +126 -9
- agno/vectordb/mongodb/__init__.py +7 -1
- agno/vectordb/mongodb/mongodb.py +112 -7
- agno/vectordb/pgvector/pgvector.py +142 -21
- agno/vectordb/pineconedb/pineconedb.py +80 -8
- agno/vectordb/qdrant/qdrant.py +125 -39
- agno/vectordb/redis/__init__.py +9 -0
- agno/vectordb/redis/redisdb.py +694 -0
- agno/vectordb/singlestore/singlestore.py +111 -25
- agno/vectordb/surrealdb/surrealdb.py +31 -5
- agno/vectordb/upstashdb/upstashdb.py +76 -8
- agno/vectordb/weaviate/weaviate.py +86 -15
- agno/workflow/__init__.py +2 -0
- agno/workflow/agent.py +299 -0
- agno/workflow/condition.py +112 -18
- agno/workflow/loop.py +69 -10
- agno/workflow/parallel.py +266 -118
- agno/workflow/router.py +110 -17
- agno/workflow/step.py +638 -129
- agno/workflow/steps.py +65 -6
- agno/workflow/types.py +61 -23
- agno/workflow/workflow.py +2085 -272
- {agno-2.0.1.dist-info → agno-2.3.0.dist-info}/METADATA +182 -58
- agno-2.3.0.dist-info/RECORD +577 -0
- agno/knowledge/reader/url_reader.py +0 -128
- agno/tools/googlesearch.py +0 -98
- agno/tools/mcp.py +0 -610
- agno/utils/models/aws_claude.py +0 -170
- agno-2.0.1.dist-info/RECORD +0 -515
- {agno-2.0.1.dist-info → agno-2.3.0.dist-info}/WHEEL +0 -0
- {agno-2.0.1.dist-info → agno-2.3.0.dist-info}/licenses/LICENSE +0 -0
- {agno-2.0.1.dist-info → agno-2.3.0.dist-info}/top_level.txt +0 -0
agno/os/router.py
CHANGED
|
@@ -1,5 +1,5 @@
|
|
|
1
1
|
import json
|
|
2
|
-
from typing import TYPE_CHECKING, Any, AsyncGenerator, Dict, List, Optional, Union, cast
|
|
2
|
+
from typing import TYPE_CHECKING, Any, AsyncGenerator, Callable, Dict, List, Optional, Union, cast
|
|
3
3
|
from uuid import uuid4
|
|
4
4
|
|
|
5
5
|
from fastapi import (
|
|
@@ -8,22 +8,26 @@ from fastapi import (
|
|
|
8
8
|
File,
|
|
9
9
|
Form,
|
|
10
10
|
HTTPException,
|
|
11
|
+
Request,
|
|
11
12
|
UploadFile,
|
|
12
13
|
WebSocket,
|
|
13
14
|
)
|
|
14
15
|
from fastapi.responses import JSONResponse, StreamingResponse
|
|
16
|
+
from packaging import version
|
|
15
17
|
from pydantic import BaseModel
|
|
16
18
|
|
|
17
19
|
from agno.agent.agent import Agent
|
|
20
|
+
from agno.db.base import AsyncBaseDb
|
|
21
|
+
from agno.db.migrations.manager import MigrationManager
|
|
22
|
+
from agno.exceptions import InputCheckError, OutputCheckError
|
|
18
23
|
from agno.media import Audio, Image, Video
|
|
19
24
|
from agno.media import File as FileMedia
|
|
20
|
-
from agno.os.auth import get_authentication_dependency
|
|
25
|
+
from agno.os.auth import get_authentication_dependency, validate_websocket_token
|
|
21
26
|
from agno.os.schema import (
|
|
22
27
|
AgentResponse,
|
|
23
28
|
AgentSummaryResponse,
|
|
24
29
|
BadRequestResponse,
|
|
25
30
|
ConfigResponse,
|
|
26
|
-
HealthResponse,
|
|
27
31
|
InterfaceResponse,
|
|
28
32
|
InternalServerErrorResponse,
|
|
29
33
|
Model,
|
|
@@ -38,6 +42,7 @@ from agno.os.schema import (
|
|
|
38
42
|
from agno.os.settings import AgnoAPISettings
|
|
39
43
|
from agno.os.utils import (
|
|
40
44
|
get_agent_by_id,
|
|
45
|
+
get_db,
|
|
41
46
|
get_team_by_id,
|
|
42
47
|
get_workflow_by_id,
|
|
43
48
|
process_audio,
|
|
@@ -45,9 +50,10 @@ from agno.os.utils import (
|
|
|
45
50
|
process_image,
|
|
46
51
|
process_video,
|
|
47
52
|
)
|
|
48
|
-
from agno.run.agent import RunErrorEvent, RunOutput
|
|
53
|
+
from agno.run.agent import RunErrorEvent, RunOutput, RunOutputEvent
|
|
49
54
|
from agno.run.team import RunErrorEvent as TeamRunErrorEvent
|
|
50
|
-
from agno.run.
|
|
55
|
+
from agno.run.team import TeamRunOutputEvent
|
|
56
|
+
from agno.run.workflow import WorkflowErrorEvent, WorkflowRunOutput, WorkflowRunOutputEvent
|
|
51
57
|
from agno.team.team import Team
|
|
52
58
|
from agno.utils.log import log_debug, log_error, log_warning, logger
|
|
53
59
|
from agno.workflow.workflow import Workflow
|
|
@@ -56,11 +62,98 @@ if TYPE_CHECKING:
|
|
|
56
62
|
from agno.os.app import AgentOS
|
|
57
63
|
|
|
58
64
|
|
|
59
|
-
def
|
|
65
|
+
async def _get_request_kwargs(request: Request, endpoint_func: Callable) -> Dict[str, Any]:
|
|
66
|
+
"""Given a Request and an endpoint function, return a dictionary with all extra form data fields.
|
|
67
|
+
Args:
|
|
68
|
+
request: The FastAPI Request object
|
|
69
|
+
endpoint_func: The function exposing the endpoint that received the request
|
|
70
|
+
|
|
71
|
+
Returns:
|
|
72
|
+
A dictionary of kwargs
|
|
73
|
+
"""
|
|
74
|
+
import inspect
|
|
75
|
+
|
|
76
|
+
form_data = await request.form()
|
|
77
|
+
sig = inspect.signature(endpoint_func)
|
|
78
|
+
known_fields = set(sig.parameters.keys())
|
|
79
|
+
kwargs: Dict[str, Any] = {key: value for key, value in form_data.items() if key not in known_fields}
|
|
80
|
+
|
|
81
|
+
# Handle JSON parameters. They are passed as strings and need to be deserialized.
|
|
82
|
+
if session_state := kwargs.get("session_state"):
|
|
83
|
+
try:
|
|
84
|
+
if isinstance(session_state, str):
|
|
85
|
+
session_state_dict = json.loads(session_state) # type: ignore
|
|
86
|
+
kwargs["session_state"] = session_state_dict
|
|
87
|
+
except json.JSONDecodeError:
|
|
88
|
+
kwargs.pop("session_state")
|
|
89
|
+
log_warning(f"Invalid session_state parameter couldn't be loaded: {session_state}")
|
|
90
|
+
|
|
91
|
+
if dependencies := kwargs.get("dependencies"):
|
|
92
|
+
try:
|
|
93
|
+
if isinstance(dependencies, str):
|
|
94
|
+
dependencies_dict = json.loads(dependencies) # type: ignore
|
|
95
|
+
kwargs["dependencies"] = dependencies_dict
|
|
96
|
+
except json.JSONDecodeError:
|
|
97
|
+
kwargs.pop("dependencies")
|
|
98
|
+
log_warning(f"Invalid dependencies parameter couldn't be loaded: {dependencies}")
|
|
99
|
+
|
|
100
|
+
if metadata := kwargs.get("metadata"):
|
|
101
|
+
try:
|
|
102
|
+
if isinstance(metadata, str):
|
|
103
|
+
metadata_dict = json.loads(metadata) # type: ignore
|
|
104
|
+
kwargs["metadata"] = metadata_dict
|
|
105
|
+
except json.JSONDecodeError:
|
|
106
|
+
kwargs.pop("metadata")
|
|
107
|
+
log_warning(f"Invalid metadata parameter couldn't be loaded: {metadata}")
|
|
108
|
+
|
|
109
|
+
if knowledge_filters := kwargs.get("knowledge_filters"):
|
|
110
|
+
try:
|
|
111
|
+
if isinstance(knowledge_filters, str):
|
|
112
|
+
knowledge_filters_dict = json.loads(knowledge_filters) # type: ignore
|
|
113
|
+
|
|
114
|
+
# Try to deserialize FilterExpr objects
|
|
115
|
+
from agno.filters import from_dict
|
|
116
|
+
|
|
117
|
+
# Check if it's a single FilterExpr dict or a list of FilterExpr dicts
|
|
118
|
+
if isinstance(knowledge_filters_dict, dict) and "op" in knowledge_filters_dict:
|
|
119
|
+
# Single FilterExpr - convert to list format
|
|
120
|
+
kwargs["knowledge_filters"] = [from_dict(knowledge_filters_dict)]
|
|
121
|
+
elif isinstance(knowledge_filters_dict, list):
|
|
122
|
+
# List of FilterExprs or mixed content
|
|
123
|
+
deserialized = []
|
|
124
|
+
for item in knowledge_filters_dict:
|
|
125
|
+
if isinstance(item, dict) and "op" in item:
|
|
126
|
+
deserialized.append(from_dict(item))
|
|
127
|
+
else:
|
|
128
|
+
# Keep non-FilterExpr items as-is
|
|
129
|
+
deserialized.append(item)
|
|
130
|
+
kwargs["knowledge_filters"] = deserialized
|
|
131
|
+
else:
|
|
132
|
+
# Regular dict filter
|
|
133
|
+
kwargs["knowledge_filters"] = knowledge_filters_dict
|
|
134
|
+
except json.JSONDecodeError:
|
|
135
|
+
kwargs.pop("knowledge_filters")
|
|
136
|
+
log_warning(f"Invalid knowledge_filters parameter couldn't be loaded: {knowledge_filters}")
|
|
137
|
+
except ValueError as e:
|
|
138
|
+
# Filter deserialization failed
|
|
139
|
+
kwargs.pop("knowledge_filters")
|
|
140
|
+
log_warning(f"Invalid FilterExpr in knowledge_filters: {e}")
|
|
141
|
+
|
|
142
|
+
# Parse boolean and null values
|
|
143
|
+
for key, value in kwargs.items():
|
|
144
|
+
if isinstance(value, str) and value.lower() in ["true", "false"]:
|
|
145
|
+
kwargs[key] = value.lower() == "true"
|
|
146
|
+
elif isinstance(value, str) and value.lower() in ["null", "none"]:
|
|
147
|
+
kwargs[key] = None
|
|
148
|
+
|
|
149
|
+
return kwargs
|
|
150
|
+
|
|
151
|
+
|
|
152
|
+
def format_sse_event(event: Union[RunOutputEvent, TeamRunOutputEvent, WorkflowRunOutputEvent]) -> str:
|
|
60
153
|
"""Parse JSON data into SSE-compliant format.
|
|
61
154
|
|
|
62
155
|
Args:
|
|
63
|
-
|
|
156
|
+
event_dict: Dictionary containing the event data
|
|
64
157
|
|
|
65
158
|
Returns:
|
|
66
159
|
SSE-formatted response:
|
|
@@ -75,20 +168,22 @@ def format_sse_event(json_data: str) -> str:
|
|
|
75
168
|
"""
|
|
76
169
|
try:
|
|
77
170
|
# Parse the JSON to extract the event type
|
|
78
|
-
|
|
79
|
-
|
|
171
|
+
event_type = event.event or "message"
|
|
172
|
+
|
|
173
|
+
# Serialize to valid JSON with double quotes and no newlines
|
|
174
|
+
clean_json = event.to_json(separators=(",", ":"), indent=None)
|
|
80
175
|
|
|
81
|
-
|
|
82
|
-
|
|
83
|
-
|
|
84
|
-
|
|
85
|
-
return f"event: message\ndata: {json_data}\n\n"
|
|
176
|
+
return f"event: {event_type}\ndata: {clean_json}\n\n"
|
|
177
|
+
except json.JSONDecodeError:
|
|
178
|
+
clean_json = event.to_json(separators=(",", ":"), indent=None)
|
|
179
|
+
return f"event: message\ndata: {clean_json}\n\n"
|
|
86
180
|
|
|
87
181
|
|
|
88
182
|
class WebSocketManager:
|
|
89
183
|
"""Manages WebSocket connections for workflow runs"""
|
|
90
184
|
|
|
91
185
|
active_connections: Dict[str, WebSocket] # {run_id: websocket}
|
|
186
|
+
authenticated_connections: Dict[WebSocket, bool] # {websocket: is_authenticated}
|
|
92
187
|
|
|
93
188
|
def __init__(
|
|
94
189
|
self,
|
|
@@ -96,22 +191,51 @@ class WebSocketManager:
|
|
|
96
191
|
):
|
|
97
192
|
# Store active connections: {run_id: websocket}
|
|
98
193
|
self.active_connections = active_connections or {}
|
|
194
|
+
# Track authentication state for each websocket
|
|
195
|
+
self.authenticated_connections = {}
|
|
99
196
|
|
|
100
|
-
async def connect(self, websocket: WebSocket):
|
|
197
|
+
async def connect(self, websocket: WebSocket, requires_auth: bool = True):
|
|
101
198
|
"""Accept WebSocket connection"""
|
|
102
199
|
await websocket.accept()
|
|
103
200
|
logger.debug("WebSocket connected")
|
|
104
201
|
|
|
105
|
-
#
|
|
202
|
+
# If auth is not required, mark as authenticated immediately
|
|
203
|
+
self.authenticated_connections[websocket] = not requires_auth
|
|
204
|
+
|
|
205
|
+
# Send connection confirmation with auth requirement info
|
|
106
206
|
await websocket.send_text(
|
|
107
207
|
json.dumps(
|
|
108
208
|
{
|
|
109
209
|
"event": "connected",
|
|
110
|
-
"message":
|
|
210
|
+
"message": (
|
|
211
|
+
"Connected to workflow events. Please authenticate to continue."
|
|
212
|
+
if requires_auth
|
|
213
|
+
else "Connected to workflow events. Authentication not required."
|
|
214
|
+
),
|
|
215
|
+
"requires_auth": requires_auth,
|
|
111
216
|
}
|
|
112
217
|
)
|
|
113
218
|
)
|
|
114
219
|
|
|
220
|
+
async def authenticate_websocket(self, websocket: WebSocket):
|
|
221
|
+
"""Mark a WebSocket connection as authenticated"""
|
|
222
|
+
self.authenticated_connections[websocket] = True
|
|
223
|
+
logger.debug("WebSocket authenticated")
|
|
224
|
+
|
|
225
|
+
# Send authentication confirmation
|
|
226
|
+
await websocket.send_text(
|
|
227
|
+
json.dumps(
|
|
228
|
+
{
|
|
229
|
+
"event": "authenticated",
|
|
230
|
+
"message": "Authentication successful. You can now send commands.",
|
|
231
|
+
}
|
|
232
|
+
)
|
|
233
|
+
)
|
|
234
|
+
|
|
235
|
+
def is_authenticated(self, websocket: WebSocket) -> bool:
|
|
236
|
+
"""Check if a WebSocket connection is authenticated"""
|
|
237
|
+
return self.authenticated_connections.get(websocket, False)
|
|
238
|
+
|
|
115
239
|
async def register_workflow_websocket(self, run_id: str, websocket: WebSocket):
|
|
116
240
|
"""Register a workflow run with its WebSocket connection"""
|
|
117
241
|
self.active_connections[run_id] = websocket
|
|
@@ -120,9 +244,26 @@ class WebSocketManager:
|
|
|
120
244
|
async def disconnect_by_run_id(self, run_id: str):
|
|
121
245
|
"""Remove WebSocket connection by run_id"""
|
|
122
246
|
if run_id in self.active_connections:
|
|
247
|
+
websocket = self.active_connections[run_id]
|
|
123
248
|
del self.active_connections[run_id]
|
|
249
|
+
# Clean up authentication state
|
|
250
|
+
if websocket in self.authenticated_connections:
|
|
251
|
+
del self.authenticated_connections[websocket]
|
|
124
252
|
logger.debug(f"WebSocket disconnected for run_id: {run_id}")
|
|
125
253
|
|
|
254
|
+
async def disconnect_websocket(self, websocket: WebSocket):
|
|
255
|
+
"""Remove WebSocket connection and clean up all associated state"""
|
|
256
|
+
# Remove from authenticated connections
|
|
257
|
+
if websocket in self.authenticated_connections:
|
|
258
|
+
del self.authenticated_connections[websocket]
|
|
259
|
+
|
|
260
|
+
# Remove from active connections
|
|
261
|
+
runs_to_remove = [run_id for run_id, ws in self.active_connections.items() if ws == websocket]
|
|
262
|
+
for run_id in runs_to_remove:
|
|
263
|
+
del self.active_connections[run_id]
|
|
264
|
+
|
|
265
|
+
logger.debug("WebSocket disconnected and cleaned up")
|
|
266
|
+
|
|
126
267
|
async def get_websocket_for_run(self, run_id: str) -> Optional[WebSocket]:
|
|
127
268
|
"""Get WebSocket connection for a workflow run"""
|
|
128
269
|
return self.active_connections.get(run_id)
|
|
@@ -143,6 +284,7 @@ async def agent_response_streamer(
|
|
|
143
284
|
audio: Optional[List[Audio]] = None,
|
|
144
285
|
videos: Optional[List[Video]] = None,
|
|
145
286
|
files: Optional[List[FileMedia]] = None,
|
|
287
|
+
**kwargs: Any,
|
|
146
288
|
) -> AsyncGenerator:
|
|
147
289
|
try:
|
|
148
290
|
run_response = agent.arun(
|
|
@@ -154,11 +296,19 @@ async def agent_response_streamer(
|
|
|
154
296
|
videos=videos,
|
|
155
297
|
files=files,
|
|
156
298
|
stream=True,
|
|
157
|
-
|
|
299
|
+
stream_events=True,
|
|
300
|
+
**kwargs,
|
|
158
301
|
)
|
|
159
302
|
async for run_response_chunk in run_response:
|
|
160
|
-
yield format_sse_event(run_response_chunk
|
|
161
|
-
|
|
303
|
+
yield format_sse_event(run_response_chunk) # type: ignore
|
|
304
|
+
except (InputCheckError, OutputCheckError) as e:
|
|
305
|
+
error_response = RunErrorEvent(
|
|
306
|
+
content=str(e),
|
|
307
|
+
error_type=e.type,
|
|
308
|
+
error_id=e.error_id,
|
|
309
|
+
additional_data=e.additional_data,
|
|
310
|
+
)
|
|
311
|
+
yield format_sse_event(error_response)
|
|
162
312
|
except Exception as e:
|
|
163
313
|
import traceback
|
|
164
314
|
|
|
@@ -166,7 +316,7 @@ async def agent_response_streamer(
|
|
|
166
316
|
error_response = RunErrorEvent(
|
|
167
317
|
content=str(e),
|
|
168
318
|
)
|
|
169
|
-
yield format_sse_event(error_response
|
|
319
|
+
yield format_sse_event(error_response)
|
|
170
320
|
|
|
171
321
|
|
|
172
322
|
async def agent_continue_response_streamer(
|
|
@@ -183,10 +333,18 @@ async def agent_continue_response_streamer(
|
|
|
183
333
|
session_id=session_id,
|
|
184
334
|
user_id=user_id,
|
|
185
335
|
stream=True,
|
|
186
|
-
|
|
336
|
+
stream_events=True,
|
|
187
337
|
)
|
|
188
338
|
async for run_response_chunk in continue_response:
|
|
189
|
-
yield format_sse_event(run_response_chunk
|
|
339
|
+
yield format_sse_event(run_response_chunk) # type: ignore
|
|
340
|
+
except (InputCheckError, OutputCheckError) as e:
|
|
341
|
+
error_response = RunErrorEvent(
|
|
342
|
+
content=str(e),
|
|
343
|
+
error_type=e.type,
|
|
344
|
+
error_id=e.error_id,
|
|
345
|
+
additional_data=e.additional_data,
|
|
346
|
+
)
|
|
347
|
+
yield format_sse_event(error_response)
|
|
190
348
|
|
|
191
349
|
except Exception as e:
|
|
192
350
|
import traceback
|
|
@@ -194,8 +352,10 @@ async def agent_continue_response_streamer(
|
|
|
194
352
|
traceback.print_exc(limit=3)
|
|
195
353
|
error_response = RunErrorEvent(
|
|
196
354
|
content=str(e),
|
|
355
|
+
error_type=e.type if hasattr(e, "type") else None,
|
|
356
|
+
error_id=e.error_id if hasattr(e, "error_id") else None,
|
|
197
357
|
)
|
|
198
|
-
yield format_sse_event(error_response
|
|
358
|
+
yield format_sse_event(error_response)
|
|
199
359
|
return
|
|
200
360
|
|
|
201
361
|
|
|
@@ -208,6 +368,7 @@ async def team_response_streamer(
|
|
|
208
368
|
audio: Optional[List[Audio]] = None,
|
|
209
369
|
videos: Optional[List[Video]] = None,
|
|
210
370
|
files: Optional[List[FileMedia]] = None,
|
|
371
|
+
**kwargs: Any,
|
|
211
372
|
) -> AsyncGenerator:
|
|
212
373
|
"""Run the given team asynchronously and yield its response"""
|
|
213
374
|
try:
|
|
@@ -220,10 +381,19 @@ async def team_response_streamer(
|
|
|
220
381
|
videos=videos,
|
|
221
382
|
files=files,
|
|
222
383
|
stream=True,
|
|
223
|
-
|
|
384
|
+
stream_events=True,
|
|
385
|
+
**kwargs,
|
|
224
386
|
)
|
|
225
387
|
async for run_response_chunk in run_response:
|
|
226
|
-
yield format_sse_event(run_response_chunk
|
|
388
|
+
yield format_sse_event(run_response_chunk) # type: ignore
|
|
389
|
+
except (InputCheckError, OutputCheckError) as e:
|
|
390
|
+
error_response = TeamRunErrorEvent(
|
|
391
|
+
content=str(e),
|
|
392
|
+
error_type=e.type,
|
|
393
|
+
error_id=e.error_id,
|
|
394
|
+
additional_data=e.additional_data,
|
|
395
|
+
)
|
|
396
|
+
yield format_sse_event(error_response)
|
|
227
397
|
|
|
228
398
|
except Exception as e:
|
|
229
399
|
import traceback
|
|
@@ -231,8 +401,10 @@ async def team_response_streamer(
|
|
|
231
401
|
traceback.print_exc()
|
|
232
402
|
error_response = TeamRunErrorEvent(
|
|
233
403
|
content=str(e),
|
|
404
|
+
error_type=e.type if hasattr(e, "type") else None,
|
|
405
|
+
error_id=e.error_id if hasattr(e, "error_id") else None,
|
|
234
406
|
)
|
|
235
|
-
yield format_sse_event(error_response
|
|
407
|
+
yield format_sse_event(error_response)
|
|
236
408
|
return
|
|
237
409
|
|
|
238
410
|
|
|
@@ -263,19 +435,42 @@ async def handle_workflow_via_websocket(websocket: WebSocket, message: dict, os:
|
|
|
263
435
|
session_id = str(uuid4())
|
|
264
436
|
|
|
265
437
|
# Execute workflow in background with streaming
|
|
266
|
-
await workflow.arun(
|
|
438
|
+
workflow_result = await workflow.arun( # type: ignore
|
|
267
439
|
input=user_message,
|
|
268
440
|
session_id=session_id,
|
|
269
441
|
user_id=user_id,
|
|
270
442
|
stream=True,
|
|
271
|
-
|
|
443
|
+
stream_events=True,
|
|
272
444
|
background=True,
|
|
273
445
|
websocket=websocket,
|
|
274
446
|
)
|
|
275
447
|
|
|
448
|
+
workflow_run_output = cast(WorkflowRunOutput, workflow_result)
|
|
449
|
+
|
|
450
|
+
await websocket_manager.register_workflow_websocket(workflow_run_output.run_id, websocket) # type: ignore
|
|
451
|
+
|
|
452
|
+
except (InputCheckError, OutputCheckError) as e:
|
|
453
|
+
await websocket.send_text(
|
|
454
|
+
json.dumps(
|
|
455
|
+
{
|
|
456
|
+
"event": "error",
|
|
457
|
+
"error": str(e),
|
|
458
|
+
"error_type": e.type,
|
|
459
|
+
"error_id": e.error_id,
|
|
460
|
+
"additional_data": e.additional_data,
|
|
461
|
+
}
|
|
462
|
+
)
|
|
463
|
+
)
|
|
276
464
|
except Exception as e:
|
|
277
465
|
logger.error(f"Error executing workflow via WebSocket: {e}")
|
|
278
|
-
|
|
466
|
+
error_payload = {
|
|
467
|
+
"event": "error",
|
|
468
|
+
"error": str(e),
|
|
469
|
+
"error_type": e.type if hasattr(e, "type") else None,
|
|
470
|
+
"error_id": e.error_id if hasattr(e, "error_id") else None,
|
|
471
|
+
}
|
|
472
|
+
error_payload = {k: v for k, v in error_payload.items() if v is not None}
|
|
473
|
+
await websocket.send_text(json.dumps(error_payload))
|
|
279
474
|
|
|
280
475
|
|
|
281
476
|
async def workflow_response_streamer(
|
|
@@ -286,17 +481,26 @@ async def workflow_response_streamer(
|
|
|
286
481
|
**kwargs: Any,
|
|
287
482
|
) -> AsyncGenerator:
|
|
288
483
|
try:
|
|
289
|
-
run_response =
|
|
484
|
+
run_response = workflow.arun(
|
|
290
485
|
input=input,
|
|
291
486
|
session_id=session_id,
|
|
292
487
|
user_id=user_id,
|
|
293
488
|
stream=True,
|
|
294
|
-
|
|
489
|
+
stream_events=True,
|
|
295
490
|
**kwargs,
|
|
296
491
|
)
|
|
297
492
|
|
|
298
493
|
async for run_response_chunk in run_response:
|
|
299
|
-
yield format_sse_event(run_response_chunk
|
|
494
|
+
yield format_sse_event(run_response_chunk) # type: ignore
|
|
495
|
+
|
|
496
|
+
except (InputCheckError, OutputCheckError) as e:
|
|
497
|
+
error_response = WorkflowErrorEvent(
|
|
498
|
+
error=str(e),
|
|
499
|
+
error_type=e.type,
|
|
500
|
+
error_id=e.error_id,
|
|
501
|
+
additional_data=e.additional_data,
|
|
502
|
+
)
|
|
503
|
+
yield format_sse_event(error_response)
|
|
300
504
|
|
|
301
505
|
except Exception as e:
|
|
302
506
|
import traceback
|
|
@@ -304,11 +508,84 @@ async def workflow_response_streamer(
|
|
|
304
508
|
traceback.print_exc()
|
|
305
509
|
error_response = WorkflowErrorEvent(
|
|
306
510
|
error=str(e),
|
|
511
|
+
error_type=e.type if hasattr(e, "type") else None,
|
|
512
|
+
error_id=e.error_id if hasattr(e, "error_id") else None,
|
|
307
513
|
)
|
|
308
|
-
yield format_sse_event(error_response
|
|
514
|
+
yield format_sse_event(error_response)
|
|
309
515
|
return
|
|
310
516
|
|
|
311
517
|
|
|
518
|
+
def get_websocket_router(
|
|
519
|
+
os: "AgentOS",
|
|
520
|
+
settings: AgnoAPISettings = AgnoAPISettings(),
|
|
521
|
+
) -> APIRouter:
|
|
522
|
+
"""
|
|
523
|
+
Create WebSocket router without HTTP authentication dependencies.
|
|
524
|
+
WebSocket endpoints handle authentication internally via message-based auth.
|
|
525
|
+
"""
|
|
526
|
+
ws_router = APIRouter()
|
|
527
|
+
|
|
528
|
+
@ws_router.websocket(
|
|
529
|
+
"/workflows/ws",
|
|
530
|
+
name="workflow_websocket",
|
|
531
|
+
)
|
|
532
|
+
async def workflow_websocket_endpoint(websocket: WebSocket):
|
|
533
|
+
"""WebSocket endpoint for receiving real-time workflow events"""
|
|
534
|
+
requires_auth = bool(settings.os_security_key)
|
|
535
|
+
await websocket_manager.connect(websocket, requires_auth=requires_auth)
|
|
536
|
+
|
|
537
|
+
try:
|
|
538
|
+
while True:
|
|
539
|
+
data = await websocket.receive_text()
|
|
540
|
+
message = json.loads(data)
|
|
541
|
+
action = message.get("action")
|
|
542
|
+
|
|
543
|
+
# Handle authentication first
|
|
544
|
+
if action == "authenticate":
|
|
545
|
+
token = message.get("token")
|
|
546
|
+
if not token:
|
|
547
|
+
await websocket.send_text(json.dumps({"event": "auth_error", "error": "Token is required"}))
|
|
548
|
+
continue
|
|
549
|
+
|
|
550
|
+
if validate_websocket_token(token, settings):
|
|
551
|
+
await websocket_manager.authenticate_websocket(websocket)
|
|
552
|
+
else:
|
|
553
|
+
await websocket.send_text(json.dumps({"event": "auth_error", "error": "Invalid token"}))
|
|
554
|
+
continue
|
|
555
|
+
|
|
556
|
+
# Check authentication for all other actions (only when required)
|
|
557
|
+
elif requires_auth and not websocket_manager.is_authenticated(websocket):
|
|
558
|
+
await websocket.send_text(
|
|
559
|
+
json.dumps(
|
|
560
|
+
{
|
|
561
|
+
"event": "auth_required",
|
|
562
|
+
"error": "Authentication required. Send authenticate action with valid token.",
|
|
563
|
+
}
|
|
564
|
+
)
|
|
565
|
+
)
|
|
566
|
+
continue
|
|
567
|
+
|
|
568
|
+
# Handle authenticated actions
|
|
569
|
+
elif action == "ping":
|
|
570
|
+
await websocket.send_text(json.dumps({"event": "pong"}))
|
|
571
|
+
|
|
572
|
+
elif action == "start-workflow":
|
|
573
|
+
# Handle workflow execution directly via WebSocket
|
|
574
|
+
await handle_workflow_via_websocket(websocket, message, os)
|
|
575
|
+
|
|
576
|
+
else:
|
|
577
|
+
await websocket.send_text(json.dumps({"event": "error", "error": f"Unknown action: {action}"}))
|
|
578
|
+
|
|
579
|
+
except Exception as e:
|
|
580
|
+
if "1012" not in str(e) and "1001" not in str(e):
|
|
581
|
+
logger.error(f"WebSocket error: {e}")
|
|
582
|
+
finally:
|
|
583
|
+
# Clean up the websocket connection
|
|
584
|
+
await websocket_manager.disconnect_websocket(websocket)
|
|
585
|
+
|
|
586
|
+
return ws_router
|
|
587
|
+
|
|
588
|
+
|
|
312
589
|
def get_base_router(
|
|
313
590
|
os: "AgentOS",
|
|
314
591
|
settings: AgnoAPISettings = AgnoAPISettings(),
|
|
@@ -321,7 +598,6 @@ def get_base_router(
|
|
|
321
598
|
- Agent management and execution
|
|
322
599
|
- Team collaboration and coordination
|
|
323
600
|
- Workflow automation and orchestration
|
|
324
|
-
- Real-time WebSocket communications
|
|
325
601
|
|
|
326
602
|
All endpoints include detailed documentation, examples, and proper error handling.
|
|
327
603
|
"""
|
|
@@ -337,24 +613,6 @@ def get_base_router(
|
|
|
337
613
|
)
|
|
338
614
|
|
|
339
615
|
# -- Main Routes ---
|
|
340
|
-
|
|
341
|
-
@router.get(
|
|
342
|
-
"/health",
|
|
343
|
-
tags=["Core"],
|
|
344
|
-
operation_id="health_check",
|
|
345
|
-
summary="Health Check",
|
|
346
|
-
description="Check the health status of the AgentOS API. Returns a simple status indicator.",
|
|
347
|
-
response_model=HealthResponse,
|
|
348
|
-
responses={
|
|
349
|
-
200: {
|
|
350
|
-
"description": "API is healthy and operational",
|
|
351
|
-
"content": {"application/json": {"example": {"status": "ok"}}},
|
|
352
|
-
}
|
|
353
|
-
},
|
|
354
|
-
)
|
|
355
|
-
async def health_check() -> HealthResponse:
|
|
356
|
-
return HealthResponse(status="ok")
|
|
357
|
-
|
|
358
616
|
@router.get(
|
|
359
617
|
"/config",
|
|
360
618
|
response_model=ConfigResponse,
|
|
@@ -375,7 +633,7 @@ def get_base_router(
|
|
|
375
633
|
"content": {
|
|
376
634
|
"application/json": {
|
|
377
635
|
"example": {
|
|
378
|
-
"
|
|
636
|
+
"id": "demo",
|
|
379
637
|
"description": "Example AgentOS configuration",
|
|
380
638
|
"available_models": [],
|
|
381
639
|
"databases": ["9c884dc4-9066-448c-9074-ef49ec7eb73c"],
|
|
@@ -437,10 +695,10 @@ def get_base_router(
|
|
|
437
695
|
)
|
|
438
696
|
async def config() -> ConfigResponse:
|
|
439
697
|
return ConfigResponse(
|
|
440
|
-
os_id=os.
|
|
698
|
+
os_id=os.id or "Unnamed OS",
|
|
441
699
|
description=os.description,
|
|
442
700
|
available_models=os.config.available_models if os.config else [],
|
|
443
|
-
databases=
|
|
701
|
+
databases=list({db.id for db_id, dbs in os.dbs.items() for db in dbs}),
|
|
444
702
|
chat=os.config.chat if os.config else None,
|
|
445
703
|
session=os._get_session_config(),
|
|
446
704
|
memory=os._get_memory_config(),
|
|
@@ -451,7 +709,7 @@ def get_base_router(
|
|
|
451
709
|
teams=[TeamSummaryResponse.from_team(team) for team in os.teams] if os.teams else [],
|
|
452
710
|
workflows=[WorkflowSummaryResponse.from_workflow(w) for w in os.workflows] if os.workflows else [],
|
|
453
711
|
interfaces=[
|
|
454
|
-
InterfaceResponse(type=interface.type, version=interface.version, route=interface.
|
|
712
|
+
InterfaceResponse(type=interface.type, version=interface.version, route=interface.prefix)
|
|
455
713
|
for interface in os.interfaces
|
|
456
714
|
],
|
|
457
715
|
)
|
|
@@ -524,7 +782,7 @@ def get_base_router(
|
|
|
524
782
|
"content": {
|
|
525
783
|
"text/event-stream": {
|
|
526
784
|
"examples": {
|
|
527
|
-
"
|
|
785
|
+
"event_stream": {
|
|
528
786
|
"summary": "Example event stream response",
|
|
529
787
|
"value": 'event: RunStarted\ndata: {"content": "Hello!", "run_id": "123..."}\n\n',
|
|
530
788
|
}
|
|
@@ -538,12 +796,39 @@ def get_base_router(
|
|
|
538
796
|
)
|
|
539
797
|
async def create_agent_run(
|
|
540
798
|
agent_id: str,
|
|
799
|
+
request: Request,
|
|
541
800
|
message: str = Form(...),
|
|
542
801
|
stream: bool = Form(False),
|
|
543
802
|
session_id: Optional[str] = Form(None),
|
|
544
803
|
user_id: Optional[str] = Form(None),
|
|
545
804
|
files: Optional[List[UploadFile]] = File(None),
|
|
546
805
|
):
|
|
806
|
+
kwargs = await _get_request_kwargs(request, create_agent_run)
|
|
807
|
+
|
|
808
|
+
if hasattr(request.state, "user_id"):
|
|
809
|
+
if user_id:
|
|
810
|
+
log_warning("User ID parameter passed in both request state and kwargs, using request state")
|
|
811
|
+
user_id = request.state.user_id
|
|
812
|
+
if hasattr(request.state, "session_id"):
|
|
813
|
+
if session_id:
|
|
814
|
+
log_warning("Session ID parameter passed in both request state and kwargs, using request state")
|
|
815
|
+
session_id = request.state.session_id
|
|
816
|
+
if hasattr(request.state, "session_state"):
|
|
817
|
+
session_state = request.state.session_state
|
|
818
|
+
if "session_state" in kwargs:
|
|
819
|
+
log_warning("Session state parameter passed in both request state and kwargs, using request state")
|
|
820
|
+
kwargs["session_state"] = session_state
|
|
821
|
+
if hasattr(request.state, "dependencies"):
|
|
822
|
+
dependencies = request.state.dependencies
|
|
823
|
+
if "dependencies" in kwargs:
|
|
824
|
+
log_warning("Dependencies parameter passed in both request state and kwargs, using request state")
|
|
825
|
+
kwargs["dependencies"] = dependencies
|
|
826
|
+
if hasattr(request.state, "metadata"):
|
|
827
|
+
metadata = request.state.metadata
|
|
828
|
+
if "metadata" in kwargs:
|
|
829
|
+
log_warning("Metadata parameter passed in both request state and kwargs, using request state")
|
|
830
|
+
kwargs["metadata"] = metadata
|
|
831
|
+
|
|
547
832
|
agent = get_agent_by_id(agent_id, os.agents)
|
|
548
833
|
if agent is None:
|
|
549
834
|
raise HTTPException(status_code=404, detail="Agent not found")
|
|
@@ -559,19 +844,39 @@ def get_base_router(
|
|
|
559
844
|
|
|
560
845
|
if files:
|
|
561
846
|
for file in files:
|
|
562
|
-
if file.content_type in [
|
|
847
|
+
if file.content_type in [
|
|
848
|
+
"image/png",
|
|
849
|
+
"image/jpeg",
|
|
850
|
+
"image/jpg",
|
|
851
|
+
"image/gif",
|
|
852
|
+
"image/webp",
|
|
853
|
+
"image/bmp",
|
|
854
|
+
"image/tiff",
|
|
855
|
+
"image/tif",
|
|
856
|
+
"image/avif",
|
|
857
|
+
]:
|
|
563
858
|
try:
|
|
564
859
|
base64_image = process_image(file)
|
|
565
860
|
base64_images.append(base64_image)
|
|
566
861
|
except Exception as e:
|
|
567
862
|
log_error(f"Error processing image {file.filename}: {e}")
|
|
568
863
|
continue
|
|
569
|
-
elif file.content_type in [
|
|
864
|
+
elif file.content_type in [
|
|
865
|
+
"audio/wav",
|
|
866
|
+
"audio/wave",
|
|
867
|
+
"audio/mp3",
|
|
868
|
+
"audio/mpeg",
|
|
869
|
+
"audio/ogg",
|
|
870
|
+
"audio/mp4",
|
|
871
|
+
"audio/m4a",
|
|
872
|
+
"audio/aac",
|
|
873
|
+
"audio/flac",
|
|
874
|
+
]:
|
|
570
875
|
try:
|
|
571
|
-
|
|
572
|
-
base64_audios.append(
|
|
876
|
+
audio = process_audio(file)
|
|
877
|
+
base64_audios.append(audio)
|
|
573
878
|
except Exception as e:
|
|
574
|
-
log_error(f"Error processing audio {file.filename}: {e}")
|
|
879
|
+
log_error(f"Error processing audio {file.filename} with content type {file.content_type}: {e}")
|
|
575
880
|
continue
|
|
576
881
|
elif file.content_type in [
|
|
577
882
|
"video/x-flv",
|
|
@@ -594,15 +899,25 @@ def get_base_router(
|
|
|
594
899
|
continue
|
|
595
900
|
elif file.content_type in [
|
|
596
901
|
"application/pdf",
|
|
597
|
-
"
|
|
902
|
+
"application/json",
|
|
903
|
+
"application/x-javascript",
|
|
598
904
|
"application/vnd.openxmlformats-officedocument.wordprocessingml.document",
|
|
905
|
+
"text/javascript",
|
|
906
|
+
"application/x-python",
|
|
907
|
+
"text/x-python",
|
|
599
908
|
"text/plain",
|
|
600
|
-
"
|
|
909
|
+
"text/html",
|
|
910
|
+
"text/css",
|
|
911
|
+
"text/md",
|
|
912
|
+
"text/csv",
|
|
913
|
+
"text/xml",
|
|
914
|
+
"text/rtf",
|
|
601
915
|
]:
|
|
602
916
|
# Process document files
|
|
603
917
|
try:
|
|
604
|
-
|
|
605
|
-
|
|
918
|
+
input_file = process_document(file)
|
|
919
|
+
if input_file is not None:
|
|
920
|
+
input_files.append(input_file)
|
|
606
921
|
except Exception as e:
|
|
607
922
|
log_error(f"Error processing file {file.filename}: {e}")
|
|
608
923
|
continue
|
|
@@ -620,24 +935,30 @@ def get_base_router(
|
|
|
620
935
|
audio=base64_audios if base64_audios else None,
|
|
621
936
|
videos=base64_videos if base64_videos else None,
|
|
622
937
|
files=input_files if input_files else None,
|
|
938
|
+
**kwargs,
|
|
623
939
|
),
|
|
624
940
|
media_type="text/event-stream",
|
|
625
941
|
)
|
|
626
942
|
else:
|
|
627
|
-
|
|
628
|
-
|
|
629
|
-
|
|
630
|
-
|
|
631
|
-
|
|
632
|
-
|
|
633
|
-
|
|
634
|
-
|
|
635
|
-
|
|
636
|
-
|
|
637
|
-
|
|
638
|
-
|
|
639
|
-
|
|
640
|
-
|
|
943
|
+
try:
|
|
944
|
+
run_response = cast(
|
|
945
|
+
RunOutput,
|
|
946
|
+
await agent.arun(
|
|
947
|
+
input=message,
|
|
948
|
+
session_id=session_id,
|
|
949
|
+
user_id=user_id,
|
|
950
|
+
images=base64_images if base64_images else None,
|
|
951
|
+
audio=base64_audios if base64_audios else None,
|
|
952
|
+
videos=base64_videos if base64_videos else None,
|
|
953
|
+
files=input_files if input_files else None,
|
|
954
|
+
stream=False,
|
|
955
|
+
**kwargs,
|
|
956
|
+
),
|
|
957
|
+
)
|
|
958
|
+
return run_response.to_dict()
|
|
959
|
+
|
|
960
|
+
except InputCheckError as e:
|
|
961
|
+
raise HTTPException(status_code=400, detail=str(e))
|
|
641
962
|
|
|
642
963
|
@router.post(
|
|
643
964
|
"/agents/{agent_id}/runs/{run_id}/cancel",
|
|
@@ -698,11 +1019,17 @@ def get_base_router(
|
|
|
698
1019
|
async def continue_agent_run(
|
|
699
1020
|
agent_id: str,
|
|
700
1021
|
run_id: str,
|
|
1022
|
+
request: Request,
|
|
701
1023
|
tools: str = Form(...), # JSON string of tools
|
|
702
1024
|
session_id: Optional[str] = Form(None),
|
|
703
1025
|
user_id: Optional[str] = Form(None),
|
|
704
1026
|
stream: bool = Form(True),
|
|
705
1027
|
):
|
|
1028
|
+
if hasattr(request.state, "user_id"):
|
|
1029
|
+
user_id = request.state.user_id
|
|
1030
|
+
if hasattr(request.state, "session_id"):
|
|
1031
|
+
session_id = request.state.session_id
|
|
1032
|
+
|
|
706
1033
|
# Parse the JSON string manually
|
|
707
1034
|
try:
|
|
708
1035
|
tools_data = json.loads(tools) if tools else None
|
|
@@ -740,17 +1067,21 @@ def get_base_router(
|
|
|
740
1067
|
media_type="text/event-stream",
|
|
741
1068
|
)
|
|
742
1069
|
else:
|
|
743
|
-
|
|
744
|
-
|
|
745
|
-
|
|
746
|
-
|
|
747
|
-
|
|
748
|
-
|
|
749
|
-
|
|
750
|
-
|
|
751
|
-
|
|
752
|
-
|
|
753
|
-
|
|
1070
|
+
try:
|
|
1071
|
+
run_response_obj = cast(
|
|
1072
|
+
RunOutput,
|
|
1073
|
+
await agent.acontinue_run(
|
|
1074
|
+
run_id=run_id, # run_id from path
|
|
1075
|
+
updated_tools=updated_tools,
|
|
1076
|
+
session_id=session_id,
|
|
1077
|
+
user_id=user_id,
|
|
1078
|
+
stream=False,
|
|
1079
|
+
),
|
|
1080
|
+
)
|
|
1081
|
+
return run_response_obj.to_dict()
|
|
1082
|
+
|
|
1083
|
+
except InputCheckError as e:
|
|
1084
|
+
raise HTTPException(status_code=400, detail=str(e))
|
|
754
1085
|
|
|
755
1086
|
@router.get(
|
|
756
1087
|
"/agents",
|
|
@@ -797,7 +1128,8 @@ def get_base_router(
|
|
|
797
1128
|
|
|
798
1129
|
agents = []
|
|
799
1130
|
for agent in os.agents:
|
|
800
|
-
|
|
1131
|
+
agent_response = await AgentResponse.from_agent(agent=agent)
|
|
1132
|
+
agents.append(agent_response)
|
|
801
1133
|
|
|
802
1134
|
return agents
|
|
803
1135
|
|
|
@@ -844,7 +1176,7 @@ def get_base_router(
|
|
|
844
1176
|
if agent is None:
|
|
845
1177
|
raise HTTPException(status_code=404, detail="Agent not found")
|
|
846
1178
|
|
|
847
|
-
return AgentResponse.from_agent(agent)
|
|
1179
|
+
return await AgentResponse.from_agent(agent)
|
|
848
1180
|
|
|
849
1181
|
# -- Team routes ---
|
|
850
1182
|
|
|
@@ -880,6 +1212,7 @@ def get_base_router(
|
|
|
880
1212
|
)
|
|
881
1213
|
async def create_team_run(
|
|
882
1214
|
team_id: str,
|
|
1215
|
+
request: Request,
|
|
883
1216
|
message: str = Form(...),
|
|
884
1217
|
stream: bool = Form(True),
|
|
885
1218
|
monitor: bool = Form(True),
|
|
@@ -887,7 +1220,34 @@ def get_base_router(
|
|
|
887
1220
|
user_id: Optional[str] = Form(None),
|
|
888
1221
|
files: Optional[List[UploadFile]] = File(None),
|
|
889
1222
|
):
|
|
890
|
-
|
|
1223
|
+
kwargs = await _get_request_kwargs(request, create_team_run)
|
|
1224
|
+
|
|
1225
|
+
if hasattr(request.state, "user_id"):
|
|
1226
|
+
if user_id:
|
|
1227
|
+
log_warning("User ID parameter passed in both request state and kwargs, using request state")
|
|
1228
|
+
user_id = request.state.user_id
|
|
1229
|
+
if hasattr(request.state, "session_id"):
|
|
1230
|
+
if session_id:
|
|
1231
|
+
log_warning("Session ID parameter passed in both request state and kwargs, using request state")
|
|
1232
|
+
session_id = request.state.session_id
|
|
1233
|
+
if hasattr(request.state, "session_state"):
|
|
1234
|
+
session_state = request.state.session_state
|
|
1235
|
+
if "session_state" in kwargs:
|
|
1236
|
+
log_warning("Session state parameter passed in both request state and kwargs, using request state")
|
|
1237
|
+
kwargs["session_state"] = session_state
|
|
1238
|
+
if hasattr(request.state, "dependencies"):
|
|
1239
|
+
dependencies = request.state.dependencies
|
|
1240
|
+
if "dependencies" in kwargs:
|
|
1241
|
+
log_warning("Dependencies parameter passed in both request state and kwargs, using request state")
|
|
1242
|
+
kwargs["dependencies"] = dependencies
|
|
1243
|
+
if hasattr(request.state, "metadata"):
|
|
1244
|
+
metadata = request.state.metadata
|
|
1245
|
+
if "metadata" in kwargs:
|
|
1246
|
+
log_warning("Metadata parameter passed in both request state and kwargs, using request state")
|
|
1247
|
+
kwargs["metadata"] = metadata
|
|
1248
|
+
|
|
1249
|
+
logger.debug(f"Creating team run: {message=} {session_id=} {monitor=} {user_id=} {team_id=} {files=} {kwargs=}")
|
|
1250
|
+
|
|
891
1251
|
team = get_team_by_id(team_id, os.teams)
|
|
892
1252
|
if team is None:
|
|
893
1253
|
raise HTTPException(status_code=404, detail="Team not found")
|
|
@@ -962,21 +1322,27 @@ def get_base_router(
|
|
|
962
1322
|
audio=base64_audios if base64_audios else None,
|
|
963
1323
|
videos=base64_videos if base64_videos else None,
|
|
964
1324
|
files=document_files if document_files else None,
|
|
1325
|
+
**kwargs,
|
|
965
1326
|
),
|
|
966
1327
|
media_type="text/event-stream",
|
|
967
1328
|
)
|
|
968
1329
|
else:
|
|
969
|
-
|
|
970
|
-
|
|
971
|
-
|
|
972
|
-
|
|
973
|
-
|
|
974
|
-
|
|
975
|
-
|
|
976
|
-
|
|
977
|
-
|
|
978
|
-
|
|
979
|
-
|
|
1330
|
+
try:
|
|
1331
|
+
run_response = await team.arun(
|
|
1332
|
+
input=message,
|
|
1333
|
+
session_id=session_id,
|
|
1334
|
+
user_id=user_id,
|
|
1335
|
+
images=base64_images if base64_images else None,
|
|
1336
|
+
audio=base64_audios if base64_audios else None,
|
|
1337
|
+
videos=base64_videos if base64_videos else None,
|
|
1338
|
+
files=document_files if document_files else None,
|
|
1339
|
+
stream=False,
|
|
1340
|
+
**kwargs,
|
|
1341
|
+
)
|
|
1342
|
+
return run_response.to_dict()
|
|
1343
|
+
|
|
1344
|
+
except InputCheckError as e:
|
|
1345
|
+
raise HTTPException(status_code=400, detail=str(e))
|
|
980
1346
|
|
|
981
1347
|
@router.post(
|
|
982
1348
|
"/teams/{team_id}/runs/{run_id}/cancel",
|
|
@@ -1095,7 +1461,8 @@ def get_base_router(
|
|
|
1095
1461
|
|
|
1096
1462
|
teams = []
|
|
1097
1463
|
for team in os.teams:
|
|
1098
|
-
|
|
1464
|
+
team_response = await TeamResponse.from_team(team=team)
|
|
1465
|
+
teams.append(team_response)
|
|
1099
1466
|
|
|
1100
1467
|
return teams
|
|
1101
1468
|
|
|
@@ -1188,39 +1555,10 @@ def get_base_router(
|
|
|
1188
1555
|
if team is None:
|
|
1189
1556
|
raise HTTPException(status_code=404, detail="Team not found")
|
|
1190
1557
|
|
|
1191
|
-
return TeamResponse.from_team(team)
|
|
1558
|
+
return await TeamResponse.from_team(team)
|
|
1192
1559
|
|
|
1193
1560
|
# -- Workflow routes ---
|
|
1194
1561
|
|
|
1195
|
-
@router.websocket(
|
|
1196
|
-
"/workflows/ws",
|
|
1197
|
-
name="workflow_websocket",
|
|
1198
|
-
)
|
|
1199
|
-
async def workflow_websocket_endpoint(websocket: WebSocket):
|
|
1200
|
-
"""WebSocket endpoint for receiving real-time workflow events"""
|
|
1201
|
-
await websocket_manager.connect(websocket)
|
|
1202
|
-
|
|
1203
|
-
try:
|
|
1204
|
-
while True:
|
|
1205
|
-
data = await websocket.receive_text()
|
|
1206
|
-
message = json.loads(data)
|
|
1207
|
-
action = message.get("action")
|
|
1208
|
-
|
|
1209
|
-
if action == "ping":
|
|
1210
|
-
await websocket.send_text(json.dumps({"event": "pong"}))
|
|
1211
|
-
|
|
1212
|
-
elif action == "start-workflow":
|
|
1213
|
-
# Handle workflow execution directly via WebSocket
|
|
1214
|
-
await handle_workflow_via_websocket(websocket, message, os)
|
|
1215
|
-
except Exception as e:
|
|
1216
|
-
if "1012" not in str(e):
|
|
1217
|
-
logger.error(f"WebSocket error: {e}")
|
|
1218
|
-
finally:
|
|
1219
|
-
# Clean up any run_ids associated with this websocket
|
|
1220
|
-
runs_to_remove = [run_id for run_id, ws in websocket_manager.active_connections.items() if ws == websocket]
|
|
1221
|
-
for run_id in runs_to_remove:
|
|
1222
|
-
await websocket_manager.disconnect_by_run_id(run_id)
|
|
1223
|
-
|
|
1224
1562
|
@router.get(
|
|
1225
1563
|
"/workflows",
|
|
1226
1564
|
response_model=List[WorkflowSummaryResponse],
|
|
@@ -1290,7 +1628,7 @@ def get_base_router(
|
|
|
1290
1628
|
if workflow is None:
|
|
1291
1629
|
raise HTTPException(status_code=404, detail="Workflow not found")
|
|
1292
1630
|
|
|
1293
|
-
return WorkflowResponse.from_workflow(workflow)
|
|
1631
|
+
return await WorkflowResponse.from_workflow(workflow)
|
|
1294
1632
|
|
|
1295
1633
|
@router.post(
|
|
1296
1634
|
"/workflows/{workflow_id}/runs",
|
|
@@ -1328,12 +1666,38 @@ def get_base_router(
|
|
|
1328
1666
|
)
|
|
1329
1667
|
async def create_workflow_run(
|
|
1330
1668
|
workflow_id: str,
|
|
1669
|
+
request: Request,
|
|
1331
1670
|
message: str = Form(...),
|
|
1332
1671
|
stream: bool = Form(True),
|
|
1333
1672
|
session_id: Optional[str] = Form(None),
|
|
1334
1673
|
user_id: Optional[str] = Form(None),
|
|
1335
|
-
**kwargs: Any,
|
|
1336
1674
|
):
|
|
1675
|
+
kwargs = await _get_request_kwargs(request, create_workflow_run)
|
|
1676
|
+
|
|
1677
|
+
if hasattr(request.state, "user_id"):
|
|
1678
|
+
if user_id:
|
|
1679
|
+
log_warning("User ID parameter passed in both request state and kwargs, using request state")
|
|
1680
|
+
user_id = request.state.user_id
|
|
1681
|
+
if hasattr(request.state, "session_id"):
|
|
1682
|
+
if session_id:
|
|
1683
|
+
log_warning("Session ID parameter passed in both request state and kwargs, using request state")
|
|
1684
|
+
session_id = request.state.session_id
|
|
1685
|
+
if hasattr(request.state, "session_state"):
|
|
1686
|
+
session_state = request.state.session_state
|
|
1687
|
+
if "session_state" in kwargs:
|
|
1688
|
+
log_warning("Session state parameter passed in both request state and kwargs, using request state")
|
|
1689
|
+
kwargs["session_state"] = session_state
|
|
1690
|
+
if hasattr(request.state, "dependencies"):
|
|
1691
|
+
dependencies = request.state.dependencies
|
|
1692
|
+
if "dependencies" in kwargs:
|
|
1693
|
+
log_warning("Dependencies parameter passed in both request state and kwargs, using request state")
|
|
1694
|
+
kwargs["dependencies"] = dependencies
|
|
1695
|
+
if hasattr(request.state, "metadata"):
|
|
1696
|
+
metadata = request.state.metadata
|
|
1697
|
+
if "metadata" in kwargs:
|
|
1698
|
+
log_warning("Metadata parameter passed in both request state and kwargs, using request state")
|
|
1699
|
+
kwargs["metadata"] = metadata
|
|
1700
|
+
|
|
1337
1701
|
# Retrieve the workflow by ID
|
|
1338
1702
|
workflow = get_workflow_by_id(workflow_id, os.workflows)
|
|
1339
1703
|
if workflow is None:
|
|
@@ -1367,6 +1731,9 @@ def get_base_router(
|
|
|
1367
1731
|
**kwargs,
|
|
1368
1732
|
)
|
|
1369
1733
|
return run_response.to_dict()
|
|
1734
|
+
|
|
1735
|
+
except InputCheckError as e:
|
|
1736
|
+
raise HTTPException(status_code=400, detail=str(e))
|
|
1370
1737
|
except Exception as e:
|
|
1371
1738
|
# Handle unexpected runtime errors
|
|
1372
1739
|
raise HTTPException(status_code=500, detail=f"Error running workflow: {str(e)}")
|
|
@@ -1397,4 +1764,54 @@ def get_base_router(
|
|
|
1397
1764
|
|
|
1398
1765
|
return JSONResponse(content={}, status_code=200)
|
|
1399
1766
|
|
|
1767
|
+
# -- Database Migration routes ---
|
|
1768
|
+
|
|
1769
|
+
@router.post(
|
|
1770
|
+
"/databases/{db_id}/migrate",
|
|
1771
|
+
tags=["Database"],
|
|
1772
|
+
operation_id="migrate_database",
|
|
1773
|
+
summary="Migrate Database",
|
|
1774
|
+
description=(
|
|
1775
|
+
"Migrate the given database schema to the given target version. "
|
|
1776
|
+
"If a target version is not provided, the database will be migrated to the latest version. "
|
|
1777
|
+
),
|
|
1778
|
+
responses={
|
|
1779
|
+
200: {
|
|
1780
|
+
"description": "Database migrated successfully",
|
|
1781
|
+
"content": {
|
|
1782
|
+
"application/json": {
|
|
1783
|
+
"example": {"message": "Database migrated successfully to version 3.0.0"},
|
|
1784
|
+
}
|
|
1785
|
+
},
|
|
1786
|
+
},
|
|
1787
|
+
404: {"description": "Database not found", "model": NotFoundResponse},
|
|
1788
|
+
500: {"description": "Failed to migrate database", "model": InternalServerErrorResponse},
|
|
1789
|
+
},
|
|
1790
|
+
)
|
|
1791
|
+
async def migrate_database(db_id: str, target_version: Optional[str] = None):
|
|
1792
|
+
db = await get_db(os.dbs, db_id)
|
|
1793
|
+
if not db:
|
|
1794
|
+
raise HTTPException(status_code=404, detail="Database not found")
|
|
1795
|
+
|
|
1796
|
+
if target_version:
|
|
1797
|
+
|
|
1798
|
+
# Use the session table as proxy for the database schema version
|
|
1799
|
+
if isinstance(db, AsyncBaseDb):
|
|
1800
|
+
current_version = await db.get_latest_schema_version(db.session_table_name)
|
|
1801
|
+
else:
|
|
1802
|
+
current_version = db.get_latest_schema_version(db.session_table_name)
|
|
1803
|
+
|
|
1804
|
+
if version.parse(target_version) > version.parse(current_version): # type: ignore
|
|
1805
|
+
MigrationManager(db).up(target_version) # type: ignore
|
|
1806
|
+
else:
|
|
1807
|
+
MigrationManager(db).down(target_version) # type: ignore
|
|
1808
|
+
|
|
1809
|
+
# If the target version is not provided, migrate to the latest version
|
|
1810
|
+
else:
|
|
1811
|
+
MigrationManager(db).up() # type: ignore
|
|
1812
|
+
|
|
1813
|
+
return JSONResponse(
|
|
1814
|
+
content={"message": f"Database migrated successfully to version {target_version}"}, status_code=200
|
|
1815
|
+
)
|
|
1816
|
+
|
|
1400
1817
|
return router
|