agno 2.1.3__py3-none-any.whl → 2.1.5__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- agno/agent/agent.py +1779 -577
- agno/db/async_postgres/__init__.py +3 -0
- agno/db/async_postgres/async_postgres.py +1668 -0
- agno/db/async_postgres/schemas.py +124 -0
- agno/db/async_postgres/utils.py +289 -0
- agno/db/base.py +237 -2
- agno/db/dynamo/dynamo.py +10 -8
- agno/db/dynamo/schemas.py +1 -10
- agno/db/dynamo/utils.py +2 -2
- agno/db/firestore/firestore.py +2 -2
- agno/db/firestore/utils.py +4 -2
- agno/db/gcs_json/gcs_json_db.py +2 -2
- agno/db/in_memory/in_memory_db.py +2 -2
- agno/db/json/json_db.py +2 -2
- agno/db/migrations/v1_to_v2.py +30 -13
- agno/db/mongo/mongo.py +18 -6
- agno/db/mysql/mysql.py +35 -13
- agno/db/postgres/postgres.py +29 -6
- agno/db/redis/redis.py +2 -2
- agno/db/singlestore/singlestore.py +2 -2
- agno/db/sqlite/sqlite.py +34 -12
- agno/db/sqlite/utils.py +8 -3
- agno/eval/accuracy.py +50 -43
- agno/eval/performance.py +6 -3
- agno/eval/reliability.py +6 -3
- agno/eval/utils.py +33 -16
- agno/exceptions.py +8 -2
- agno/knowledge/embedder/fastembed.py +1 -1
- agno/knowledge/knowledge.py +260 -46
- agno/knowledge/reader/pdf_reader.py +4 -6
- agno/knowledge/reader/reader_factory.py +2 -3
- agno/memory/manager.py +241 -33
- agno/models/anthropic/claude.py +37 -0
- agno/os/app.py +15 -10
- agno/os/interfaces/a2a/router.py +3 -5
- agno/os/interfaces/agui/router.py +4 -1
- agno/os/interfaces/agui/utils.py +33 -6
- agno/os/interfaces/slack/router.py +2 -4
- agno/os/mcp.py +98 -41
- agno/os/router.py +23 -0
- agno/os/routers/evals/evals.py +52 -20
- agno/os/routers/evals/utils.py +14 -14
- agno/os/routers/knowledge/knowledge.py +130 -9
- agno/os/routers/knowledge/schemas.py +57 -0
- agno/os/routers/memory/memory.py +116 -44
- agno/os/routers/metrics/metrics.py +16 -6
- agno/os/routers/session/session.py +65 -22
- agno/os/schema.py +38 -0
- agno/os/utils.py +69 -13
- agno/reasoning/anthropic.py +80 -0
- agno/reasoning/gemini.py +73 -0
- agno/reasoning/openai.py +5 -0
- agno/reasoning/vertexai.py +76 -0
- agno/session/workflow.py +69 -1
- agno/team/team.py +934 -241
- agno/tools/function.py +36 -18
- agno/tools/google_drive.py +270 -0
- agno/tools/googlesheets.py +20 -5
- agno/tools/mcp_toolbox.py +3 -3
- agno/tools/scrapegraph.py +1 -1
- agno/utils/models/claude.py +3 -1
- agno/utils/print_response/workflow.py +112 -12
- agno/utils/streamlit.py +1 -1
- agno/vectordb/base.py +22 -1
- agno/vectordb/cassandra/cassandra.py +9 -0
- agno/vectordb/chroma/chromadb.py +26 -6
- agno/vectordb/clickhouse/clickhousedb.py +9 -1
- agno/vectordb/couchbase/couchbase.py +11 -0
- agno/vectordb/lancedb/lance_db.py +20 -0
- agno/vectordb/langchaindb/langchaindb.py +11 -0
- agno/vectordb/lightrag/lightrag.py +9 -0
- agno/vectordb/llamaindex/llamaindexdb.py +15 -1
- agno/vectordb/milvus/milvus.py +23 -0
- agno/vectordb/mongodb/mongodb.py +22 -0
- agno/vectordb/pgvector/pgvector.py +19 -0
- agno/vectordb/pineconedb/pineconedb.py +35 -4
- agno/vectordb/qdrant/qdrant.py +24 -0
- agno/vectordb/singlestore/singlestore.py +25 -17
- agno/vectordb/surrealdb/surrealdb.py +18 -1
- agno/vectordb/upstashdb/upstashdb.py +26 -1
- agno/vectordb/weaviate/weaviate.py +18 -0
- agno/workflow/condition.py +29 -0
- agno/workflow/loop.py +29 -0
- agno/workflow/parallel.py +141 -113
- agno/workflow/router.py +29 -0
- agno/workflow/step.py +146 -25
- agno/workflow/steps.py +29 -0
- agno/workflow/types.py +26 -1
- agno/workflow/workflow.py +507 -22
- {agno-2.1.3.dist-info → agno-2.1.5.dist-info}/METADATA +100 -41
- {agno-2.1.3.dist-info → agno-2.1.5.dist-info}/RECORD +94 -86
- {agno-2.1.3.dist-info → agno-2.1.5.dist-info}/WHEEL +0 -0
- {agno-2.1.3.dist-info → agno-2.1.5.dist-info}/licenses/LICENSE +0 -0
- {agno-2.1.3.dist-info → agno-2.1.5.dist-info}/top_level.txt +0 -0
agno/os/mcp.py
CHANGED
|
@@ -1,7 +1,7 @@
|
|
|
1
1
|
"""Router for MCP interface providing Model Context Protocol endpoints."""
|
|
2
2
|
|
|
3
3
|
import logging
|
|
4
|
-
from typing import TYPE_CHECKING, List, Optional
|
|
4
|
+
from typing import TYPE_CHECKING, List, Optional, cast
|
|
5
5
|
from uuid import uuid4
|
|
6
6
|
|
|
7
7
|
from fastmcp import FastMCP
|
|
@@ -9,7 +9,7 @@ from fastmcp.server.http import (
|
|
|
9
9
|
StarletteWithLifespan,
|
|
10
10
|
)
|
|
11
11
|
|
|
12
|
-
from agno.db.base import SessionType
|
|
12
|
+
from agno.db.base import AsyncBaseDb, SessionType
|
|
13
13
|
from agno.db.schemas import UserMemory
|
|
14
14
|
from agno.os.routers.memory.schemas import (
|
|
15
15
|
UserMemorySchema,
|
|
@@ -104,14 +104,25 @@ def get_mcp_server(
|
|
|
104
104
|
sort_order: str = "desc",
|
|
105
105
|
):
|
|
106
106
|
db = get_db(os.dbs, db_id)
|
|
107
|
-
|
|
108
|
-
|
|
109
|
-
|
|
110
|
-
|
|
111
|
-
|
|
112
|
-
|
|
113
|
-
|
|
114
|
-
|
|
107
|
+
if isinstance(db, AsyncBaseDb):
|
|
108
|
+
db = cast(AsyncBaseDb, db)
|
|
109
|
+
sessions = await db.get_sessions(
|
|
110
|
+
session_type=SessionType.AGENT,
|
|
111
|
+
component_id=agent_id,
|
|
112
|
+
user_id=user_id,
|
|
113
|
+
sort_by=sort_by,
|
|
114
|
+
sort_order=sort_order,
|
|
115
|
+
deserialize=False,
|
|
116
|
+
)
|
|
117
|
+
else:
|
|
118
|
+
sessions = db.get_sessions(
|
|
119
|
+
session_type=SessionType.AGENT,
|
|
120
|
+
component_id=agent_id,
|
|
121
|
+
user_id=user_id,
|
|
122
|
+
sort_by=sort_by,
|
|
123
|
+
sort_order=sort_order,
|
|
124
|
+
deserialize=False,
|
|
125
|
+
)
|
|
115
126
|
|
|
116
127
|
return {
|
|
117
128
|
"data": [SessionSchema.from_dict(session) for session in sessions], # type: ignore
|
|
@@ -126,14 +137,25 @@ def get_mcp_server(
|
|
|
126
137
|
sort_order: str = "desc",
|
|
127
138
|
):
|
|
128
139
|
db = get_db(os.dbs, db_id)
|
|
129
|
-
|
|
130
|
-
|
|
131
|
-
|
|
132
|
-
|
|
133
|
-
|
|
134
|
-
|
|
135
|
-
|
|
136
|
-
|
|
140
|
+
if isinstance(db, AsyncBaseDb):
|
|
141
|
+
db = cast(AsyncBaseDb, db)
|
|
142
|
+
sessions = await db.get_sessions(
|
|
143
|
+
session_type=SessionType.TEAM,
|
|
144
|
+
component_id=team_id,
|
|
145
|
+
user_id=user_id,
|
|
146
|
+
sort_by=sort_by,
|
|
147
|
+
sort_order=sort_order,
|
|
148
|
+
deserialize=False,
|
|
149
|
+
)
|
|
150
|
+
else:
|
|
151
|
+
sessions = db.get_sessions(
|
|
152
|
+
session_type=SessionType.TEAM,
|
|
153
|
+
component_id=team_id,
|
|
154
|
+
user_id=user_id,
|
|
155
|
+
sort_by=sort_by,
|
|
156
|
+
sort_order=sort_order,
|
|
157
|
+
deserialize=False,
|
|
158
|
+
)
|
|
137
159
|
|
|
138
160
|
return {
|
|
139
161
|
"data": [SessionSchema.from_dict(session) for session in sessions], # type: ignore
|
|
@@ -148,14 +170,25 @@ def get_mcp_server(
|
|
|
148
170
|
sort_order: str = "desc",
|
|
149
171
|
):
|
|
150
172
|
db = get_db(os.dbs, db_id)
|
|
151
|
-
|
|
152
|
-
|
|
153
|
-
|
|
154
|
-
|
|
155
|
-
|
|
156
|
-
|
|
157
|
-
|
|
158
|
-
|
|
173
|
+
if isinstance(db, AsyncBaseDb):
|
|
174
|
+
db = cast(AsyncBaseDb, db)
|
|
175
|
+
sessions = await db.get_sessions(
|
|
176
|
+
session_type=SessionType.WORKFLOW,
|
|
177
|
+
component_id=workflow_id,
|
|
178
|
+
user_id=user_id,
|
|
179
|
+
sort_by=sort_by,
|
|
180
|
+
sort_order=sort_order,
|
|
181
|
+
deserialize=False,
|
|
182
|
+
)
|
|
183
|
+
else:
|
|
184
|
+
sessions = db.get_sessions(
|
|
185
|
+
session_type=SessionType.WORKFLOW,
|
|
186
|
+
component_id=workflow_id,
|
|
187
|
+
user_id=user_id,
|
|
188
|
+
sort_by=sort_by,
|
|
189
|
+
sort_order=sort_order,
|
|
190
|
+
deserialize=False,
|
|
191
|
+
)
|
|
159
192
|
|
|
160
193
|
return {
|
|
161
194
|
"data": [SessionSchema.from_dict(session) for session in sessions], # type: ignore
|
|
@@ -192,12 +225,21 @@ def get_mcp_server(
|
|
|
192
225
|
db_id: Optional[str] = None,
|
|
193
226
|
):
|
|
194
227
|
db = get_db(os.dbs, db_id)
|
|
195
|
-
|
|
196
|
-
|
|
197
|
-
|
|
198
|
-
|
|
199
|
-
|
|
200
|
-
|
|
228
|
+
if isinstance(db, AsyncBaseDb):
|
|
229
|
+
db = cast(AsyncBaseDb, db)
|
|
230
|
+
user_memories = await db.get_user_memories(
|
|
231
|
+
user_id=user_id,
|
|
232
|
+
sort_by=sort_by,
|
|
233
|
+
sort_order=sort_order,
|
|
234
|
+
deserialize=False,
|
|
235
|
+
)
|
|
236
|
+
else:
|
|
237
|
+
user_memories = db.get_user_memories(
|
|
238
|
+
user_id=user_id,
|
|
239
|
+
sort_by=sort_by,
|
|
240
|
+
sort_order=sort_order,
|
|
241
|
+
deserialize=False,
|
|
242
|
+
)
|
|
201
243
|
return {
|
|
202
244
|
"data": [UserMemorySchema.from_dict(user_memory) for user_memory in user_memories], # type: ignore
|
|
203
245
|
}
|
|
@@ -210,14 +252,25 @@ def get_mcp_server(
|
|
|
210
252
|
user_id: str,
|
|
211
253
|
) -> UserMemorySchema:
|
|
212
254
|
db = get_db(os.dbs, db_id)
|
|
213
|
-
|
|
214
|
-
|
|
215
|
-
|
|
216
|
-
memory=
|
|
217
|
-
|
|
218
|
-
|
|
219
|
-
|
|
220
|
-
|
|
255
|
+
if isinstance(db, AsyncBaseDb):
|
|
256
|
+
db = cast(AsyncBaseDb, db)
|
|
257
|
+
user_memory = await db.upsert_user_memory(
|
|
258
|
+
memory=UserMemory(
|
|
259
|
+
memory_id=memory_id,
|
|
260
|
+
memory=memory,
|
|
261
|
+
user_id=user_id,
|
|
262
|
+
),
|
|
263
|
+
deserialize=False,
|
|
264
|
+
)
|
|
265
|
+
else:
|
|
266
|
+
user_memory = db.upsert_user_memory(
|
|
267
|
+
memory=UserMemory(
|
|
268
|
+
memory_id=memory_id,
|
|
269
|
+
memory=memory,
|
|
270
|
+
user_id=user_id,
|
|
271
|
+
),
|
|
272
|
+
deserialize=False,
|
|
273
|
+
)
|
|
221
274
|
if not user_memory:
|
|
222
275
|
raise Exception("Failed to update memory")
|
|
223
276
|
|
|
@@ -229,7 +282,11 @@ def get_mcp_server(
|
|
|
229
282
|
memory_id: str,
|
|
230
283
|
) -> None:
|
|
231
284
|
db = get_db(os.dbs, db_id)
|
|
232
|
-
db
|
|
285
|
+
if isinstance(db, AsyncBaseDb):
|
|
286
|
+
db = cast(AsyncBaseDb, db)
|
|
287
|
+
await db.delete_user_memory(memory_id=memory_id)
|
|
288
|
+
else:
|
|
289
|
+
db.delete_user_memory(memory_id=memory_id)
|
|
233
290
|
|
|
234
291
|
mcp_app = mcp.http_app(path="/mcp")
|
|
235
292
|
return mcp_app
|
agno/os/router.py
CHANGED
|
@@ -92,6 +92,14 @@ async def _get_request_kwargs(request: Request, endpoint_func: Callable) -> Dict
|
|
|
92
92
|
kwargs.pop("dependencies")
|
|
93
93
|
log_warning(f"Invalid dependencies parameter couldn't be loaded: {dependencies}")
|
|
94
94
|
|
|
95
|
+
if metadata := kwargs.get("metadata"):
|
|
96
|
+
try:
|
|
97
|
+
metadata_dict = json.loads(metadata) # type: ignore
|
|
98
|
+
kwargs["metadata"] = metadata_dict
|
|
99
|
+
except json.JSONDecodeError:
|
|
100
|
+
kwargs.pop("metadata")
|
|
101
|
+
log_warning(f"Invalid metadata parameter couldn't be loaded: {metadata}")
|
|
102
|
+
|
|
95
103
|
return kwargs
|
|
96
104
|
|
|
97
105
|
|
|
@@ -769,6 +777,11 @@ def get_base_router(
|
|
|
769
777
|
if "dependencies" in kwargs:
|
|
770
778
|
log_warning("Dependencies parameter passed in both request state and kwargs, using request state")
|
|
771
779
|
kwargs["dependencies"] = dependencies
|
|
780
|
+
if hasattr(request.state, "metadata"):
|
|
781
|
+
metadata = request.state.metadata
|
|
782
|
+
if "metadata" in kwargs:
|
|
783
|
+
log_warning("Metadata parameter passed in both request state and kwargs, using request state")
|
|
784
|
+
kwargs["metadata"] = metadata
|
|
772
785
|
|
|
773
786
|
agent = get_agent_by_id(agent_id, os.agents)
|
|
774
787
|
if agent is None:
|
|
@@ -1180,6 +1193,11 @@ def get_base_router(
|
|
|
1180
1193
|
if "dependencies" in kwargs:
|
|
1181
1194
|
log_warning("Dependencies parameter passed in both request state and kwargs, using request state")
|
|
1182
1195
|
kwargs["dependencies"] = dependencies
|
|
1196
|
+
if hasattr(request.state, "metadata"):
|
|
1197
|
+
metadata = request.state.metadata
|
|
1198
|
+
if "metadata" in kwargs:
|
|
1199
|
+
log_warning("Metadata parameter passed in both request state and kwargs, using request state")
|
|
1200
|
+
kwargs["metadata"] = metadata
|
|
1183
1201
|
|
|
1184
1202
|
logger.debug(f"Creating team run: {message=} {session_id=} {monitor=} {user_id=} {team_id=} {files=} {kwargs=}")
|
|
1185
1203
|
|
|
@@ -1626,6 +1644,11 @@ def get_base_router(
|
|
|
1626
1644
|
if "dependencies" in kwargs:
|
|
1627
1645
|
log_warning("Dependencies parameter passed in both request state and kwargs, using request state")
|
|
1628
1646
|
kwargs["dependencies"] = dependencies
|
|
1647
|
+
if hasattr(request.state, "metadata"):
|
|
1648
|
+
metadata = request.state.metadata
|
|
1649
|
+
if "metadata" in kwargs:
|
|
1650
|
+
log_warning("Metadata parameter passed in both request state and kwargs, using request state")
|
|
1651
|
+
kwargs["metadata"] = metadata
|
|
1629
1652
|
|
|
1630
1653
|
# Retrieve the workflow by ID
|
|
1631
1654
|
workflow = get_workflow_by_id(workflow_id, os.workflows)
|
agno/os/routers/evals/evals.py
CHANGED
|
@@ -1,11 +1,11 @@
|
|
|
1
1
|
import logging
|
|
2
2
|
from copy import deepcopy
|
|
3
|
-
from typing import List, Optional
|
|
3
|
+
from typing import List, Optional, Union, cast
|
|
4
4
|
|
|
5
5
|
from fastapi import APIRouter, Depends, HTTPException, Query
|
|
6
6
|
|
|
7
7
|
from agno.agent.agent import Agent
|
|
8
|
-
from agno.db.base import BaseDb
|
|
8
|
+
from agno.db.base import AsyncBaseDb, BaseDb
|
|
9
9
|
from agno.db.schemas.evals import EvalFilterType, EvalType
|
|
10
10
|
from agno.models.utils import get_model
|
|
11
11
|
from agno.os.auth import get_authentication_dependency
|
|
@@ -34,7 +34,7 @@ logger = logging.getLogger(__name__)
|
|
|
34
34
|
|
|
35
35
|
|
|
36
36
|
def get_eval_router(
|
|
37
|
-
dbs: dict[str, BaseDb],
|
|
37
|
+
dbs: dict[str, Union[BaseDb, AsyncBaseDb]],
|
|
38
38
|
agents: Optional[List[Agent]] = None,
|
|
39
39
|
teams: Optional[List[Team]] = None,
|
|
40
40
|
settings: AgnoAPISettings = AgnoAPISettings(),
|
|
@@ -55,7 +55,10 @@ def get_eval_router(
|
|
|
55
55
|
|
|
56
56
|
|
|
57
57
|
def attach_routes(
|
|
58
|
-
router: APIRouter,
|
|
58
|
+
router: APIRouter,
|
|
59
|
+
dbs: dict[str, Union[BaseDb, AsyncBaseDb]],
|
|
60
|
+
agents: Optional[List[Agent]] = None,
|
|
61
|
+
teams: Optional[List[Team]] = None,
|
|
59
62
|
) -> APIRouter:
|
|
60
63
|
@router.get(
|
|
61
64
|
"/eval-runs",
|
|
@@ -114,19 +117,36 @@ def attach_routes(
|
|
|
114
117
|
db_id: Optional[str] = Query(default=None, description="The ID of the database to use"),
|
|
115
118
|
) -> PaginatedResponse[EvalSchema]:
|
|
116
119
|
db = get_db(dbs, db_id)
|
|
117
|
-
|
|
118
|
-
|
|
119
|
-
|
|
120
|
-
|
|
121
|
-
|
|
122
|
-
|
|
123
|
-
|
|
124
|
-
|
|
125
|
-
|
|
126
|
-
|
|
127
|
-
|
|
128
|
-
|
|
129
|
-
|
|
120
|
+
|
|
121
|
+
if isinstance(db, AsyncBaseDb):
|
|
122
|
+
db = cast(AsyncBaseDb, db)
|
|
123
|
+
eval_runs, total_count = await db.get_eval_runs(
|
|
124
|
+
limit=limit,
|
|
125
|
+
page=page,
|
|
126
|
+
sort_by=sort_by,
|
|
127
|
+
sort_order=sort_order,
|
|
128
|
+
agent_id=agent_id,
|
|
129
|
+
team_id=team_id,
|
|
130
|
+
workflow_id=workflow_id,
|
|
131
|
+
model_id=model_id,
|
|
132
|
+
eval_type=eval_types,
|
|
133
|
+
filter_type=filter_type,
|
|
134
|
+
deserialize=False,
|
|
135
|
+
)
|
|
136
|
+
else:
|
|
137
|
+
eval_runs, total_count = db.get_eval_runs( # type: ignore
|
|
138
|
+
limit=limit,
|
|
139
|
+
page=page,
|
|
140
|
+
sort_by=sort_by,
|
|
141
|
+
sort_order=sort_order,
|
|
142
|
+
agent_id=agent_id,
|
|
143
|
+
team_id=team_id,
|
|
144
|
+
workflow_id=workflow_id,
|
|
145
|
+
model_id=model_id,
|
|
146
|
+
eval_type=eval_types,
|
|
147
|
+
filter_type=filter_type,
|
|
148
|
+
deserialize=False,
|
|
149
|
+
)
|
|
130
150
|
|
|
131
151
|
return PaginatedResponse(
|
|
132
152
|
data=[EvalSchema.from_dict(eval_run) for eval_run in eval_runs], # type: ignore
|
|
@@ -180,7 +200,11 @@ def attach_routes(
|
|
|
180
200
|
db_id: Optional[str] = Query(default=None, description="The ID of the database to use"),
|
|
181
201
|
) -> EvalSchema:
|
|
182
202
|
db = get_db(dbs, db_id)
|
|
183
|
-
|
|
203
|
+
if isinstance(db, AsyncBaseDb):
|
|
204
|
+
db = cast(AsyncBaseDb, db)
|
|
205
|
+
eval_run = await db.get_eval_run(eval_run_id=eval_run_id, deserialize=False)
|
|
206
|
+
else:
|
|
207
|
+
eval_run = db.get_eval_run(eval_run_id=eval_run_id, deserialize=False)
|
|
184
208
|
if not eval_run:
|
|
185
209
|
raise HTTPException(status_code=404, detail=f"Eval run with id '{eval_run_id}' not found")
|
|
186
210
|
|
|
@@ -203,7 +227,11 @@ def attach_routes(
|
|
|
203
227
|
) -> None:
|
|
204
228
|
try:
|
|
205
229
|
db = get_db(dbs, db_id)
|
|
206
|
-
db
|
|
230
|
+
if isinstance(db, AsyncBaseDb):
|
|
231
|
+
db = cast(AsyncBaseDb, db)
|
|
232
|
+
await db.delete_eval_runs(eval_run_ids=request.eval_run_ids)
|
|
233
|
+
else:
|
|
234
|
+
db.delete_eval_runs(eval_run_ids=request.eval_run_ids)
|
|
207
235
|
except Exception as e:
|
|
208
236
|
raise HTTPException(status_code=500, detail=f"Failed to delete eval runs: {e}")
|
|
209
237
|
|
|
@@ -252,7 +280,11 @@ def attach_routes(
|
|
|
252
280
|
) -> EvalSchema:
|
|
253
281
|
try:
|
|
254
282
|
db = get_db(dbs, db_id)
|
|
255
|
-
|
|
283
|
+
if isinstance(db, AsyncBaseDb):
|
|
284
|
+
db = cast(AsyncBaseDb, db)
|
|
285
|
+
eval_run = await db.rename_eval_run(eval_run_id=eval_run_id, name=request.name, deserialize=False)
|
|
286
|
+
else:
|
|
287
|
+
eval_run = db.rename_eval_run(eval_run_id=eval_run_id, name=request.name, deserialize=False)
|
|
256
288
|
except Exception as e:
|
|
257
289
|
raise HTTPException(status_code=500, detail=f"Failed to rename eval run: {e}")
|
|
258
290
|
|
agno/os/routers/evals/utils.py
CHANGED
|
@@ -1,9 +1,9 @@
|
|
|
1
|
-
from typing import Optional
|
|
1
|
+
from typing import Optional, Union
|
|
2
2
|
|
|
3
3
|
from fastapi import HTTPException
|
|
4
4
|
|
|
5
5
|
from agno.agent.agent import Agent
|
|
6
|
-
from agno.db.base import BaseDb
|
|
6
|
+
from agno.db.base import AsyncBaseDb, BaseDb
|
|
7
7
|
from agno.eval.accuracy import AccuracyEval
|
|
8
8
|
from agno.eval.performance import PerformanceEval
|
|
9
9
|
from agno.eval.reliability import ReliabilityEval
|
|
@@ -14,7 +14,7 @@ from agno.team.team import Team
|
|
|
14
14
|
|
|
15
15
|
async def run_accuracy_eval(
|
|
16
16
|
eval_run_input: EvalRunInput,
|
|
17
|
-
db: BaseDb,
|
|
17
|
+
db: Union[BaseDb, AsyncBaseDb],
|
|
18
18
|
agent: Optional[Agent] = None,
|
|
19
19
|
team: Optional[Team] = None,
|
|
20
20
|
default_model: Optional[Model] = None,
|
|
@@ -35,7 +35,7 @@ async def run_accuracy_eval(
|
|
|
35
35
|
name=eval_run_input.name,
|
|
36
36
|
)
|
|
37
37
|
|
|
38
|
-
result =
|
|
38
|
+
result = accuracy_eval.run(print_results=False, print_summary=False)
|
|
39
39
|
if not result:
|
|
40
40
|
raise HTTPException(status_code=500, detail="Failed to run accuracy evaluation")
|
|
41
41
|
|
|
@@ -52,7 +52,7 @@ async def run_accuracy_eval(
|
|
|
52
52
|
|
|
53
53
|
async def run_performance_eval(
|
|
54
54
|
eval_run_input: EvalRunInput,
|
|
55
|
-
db: BaseDb,
|
|
55
|
+
db: Union[BaseDb, AsyncBaseDb],
|
|
56
56
|
agent: Optional[Agent] = None,
|
|
57
57
|
team: Optional[Team] = None,
|
|
58
58
|
default_model: Optional[Model] = None,
|
|
@@ -60,16 +60,16 @@ async def run_performance_eval(
|
|
|
60
60
|
"""Run a performance evaluation for the given agent or team"""
|
|
61
61
|
if agent:
|
|
62
62
|
|
|
63
|
-
|
|
64
|
-
return
|
|
63
|
+
def run_component(): # type: ignore
|
|
64
|
+
return agent.run(eval_run_input.input)
|
|
65
65
|
|
|
66
66
|
model_id = agent.model.id if agent and agent.model else None
|
|
67
67
|
model_provider = agent.model.provider if agent and agent.model else None
|
|
68
68
|
|
|
69
69
|
elif team:
|
|
70
70
|
|
|
71
|
-
|
|
72
|
-
return
|
|
71
|
+
def run_component():
|
|
72
|
+
return team.run(eval_run_input.input)
|
|
73
73
|
|
|
74
74
|
model_id = team.model.id if team and team.model else None
|
|
75
75
|
model_provider = team.model.provider if team and team.model else None
|
|
@@ -85,7 +85,7 @@ async def run_performance_eval(
|
|
|
85
85
|
model_id=model_id,
|
|
86
86
|
model_provider=model_provider,
|
|
87
87
|
)
|
|
88
|
-
result =
|
|
88
|
+
result = performance_eval.run(print_results=False, print_summary=False)
|
|
89
89
|
if not result:
|
|
90
90
|
raise HTTPException(status_code=500, detail="Failed to run performance evaluation")
|
|
91
91
|
|
|
@@ -109,7 +109,7 @@ async def run_performance_eval(
|
|
|
109
109
|
|
|
110
110
|
async def run_reliability_eval(
|
|
111
111
|
eval_run_input: EvalRunInput,
|
|
112
|
-
db: BaseDb,
|
|
112
|
+
db: Union[BaseDb, AsyncBaseDb],
|
|
113
113
|
agent: Optional[Agent] = None,
|
|
114
114
|
team: Optional[Team] = None,
|
|
115
115
|
default_model: Optional[Model] = None,
|
|
@@ -119,7 +119,7 @@ async def run_reliability_eval(
|
|
|
119
119
|
raise HTTPException(status_code=400, detail="expected_tool_calls is required for reliability evaluations")
|
|
120
120
|
|
|
121
121
|
if agent:
|
|
122
|
-
agent_response =
|
|
122
|
+
agent_response = agent.run(eval_run_input.input)
|
|
123
123
|
reliability_eval = ReliabilityEval(
|
|
124
124
|
db=db,
|
|
125
125
|
name=eval_run_input.name,
|
|
@@ -130,7 +130,7 @@ async def run_reliability_eval(
|
|
|
130
130
|
model_provider = agent.model.provider if agent and agent.model else None
|
|
131
131
|
|
|
132
132
|
elif team:
|
|
133
|
-
team_response =
|
|
133
|
+
team_response = team.run(eval_run_input.input)
|
|
134
134
|
reliability_eval = ReliabilityEval(
|
|
135
135
|
db=db,
|
|
136
136
|
name=eval_run_input.name,
|
|
@@ -140,7 +140,7 @@ async def run_reliability_eval(
|
|
|
140
140
|
model_id = team.model.id if team and team.model else None
|
|
141
141
|
model_provider = team.model.provider if team and team.model else None
|
|
142
142
|
|
|
143
|
-
result =
|
|
143
|
+
result = reliability_eval.run(print_results=False)
|
|
144
144
|
if not result:
|
|
145
145
|
raise HTTPException(status_code=500, detail="Failed to run reliability evaluation")
|
|
146
146
|
|