agno 2.3.19__py3-none-any.whl → 2.3.21__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- agno/agent/agent.py +2466 -2048
- agno/db/dynamo/utils.py +26 -3
- agno/db/firestore/utils.py +25 -10
- agno/db/gcs_json/utils.py +14 -2
- agno/db/in_memory/utils.py +14 -2
- agno/db/json/utils.py +14 -2
- agno/db/mysql/utils.py +13 -3
- agno/db/postgres/utils.py +13 -3
- agno/db/redis/utils.py +26 -10
- agno/db/schemas/memory.py +15 -19
- agno/db/singlestore/utils.py +13 -3
- agno/db/sqlite/utils.py +15 -3
- agno/db/utils.py +22 -0
- agno/eval/agent_as_judge.py +24 -14
- agno/knowledge/embedder/mistral.py +1 -1
- agno/models/litellm/chat.py +6 -0
- agno/os/routers/evals/evals.py +0 -9
- agno/os/routers/evals/utils.py +6 -6
- agno/os/routers/knowledge/schemas.py +1 -1
- agno/os/routers/memory/schemas.py +14 -1
- agno/os/routers/metrics/schemas.py +1 -1
- agno/os/schema.py +11 -9
- agno/run/__init__.py +2 -4
- agno/run/agent.py +19 -19
- agno/run/cancel.py +65 -52
- agno/run/cancellation_management/__init__.py +9 -0
- agno/run/cancellation_management/base.py +78 -0
- agno/run/cancellation_management/in_memory_cancellation_manager.py +100 -0
- agno/run/cancellation_management/redis_cancellation_manager.py +236 -0
- agno/run/team.py +19 -19
- agno/team/team.py +1217 -1136
- agno/utils/response.py +1 -13
- agno/vectordb/weaviate/__init__.py +1 -1
- agno/workflow/workflow.py +23 -16
- {agno-2.3.19.dist-info → agno-2.3.21.dist-info}/METADATA +60 -129
- {agno-2.3.19.dist-info → agno-2.3.21.dist-info}/RECORD +39 -35
- {agno-2.3.19.dist-info → agno-2.3.21.dist-info}/WHEEL +0 -0
- {agno-2.3.19.dist-info → agno-2.3.21.dist-info}/licenses/LICENSE +0 -0
- {agno-2.3.19.dist-info → agno-2.3.21.dist-info}/top_level.txt +0 -0
agno/os/routers/evals/utils.py
CHANGED
|
@@ -68,15 +68,11 @@ async def run_agent_as_judge_eval(
|
|
|
68
68
|
if agent:
|
|
69
69
|
agent_response = await agent.arun(eval_run_input.input, stream=False)
|
|
70
70
|
output = str(agent_response.content) if agent_response.content else ""
|
|
71
|
-
model_id = agent.model.id if agent and agent.model else None
|
|
72
|
-
model_provider = agent.model.provider if agent and agent.model else None
|
|
73
71
|
agent_id = agent.id
|
|
74
72
|
team_id = None
|
|
75
73
|
elif team:
|
|
76
74
|
team_response = await team.arun(eval_run_input.input, stream=False)
|
|
77
75
|
output = str(team_response.content) if team_response.content else ""
|
|
78
|
-
model_id = team.model.id if team and team.model else None
|
|
79
|
-
model_provider = team.model.provider if team and team.model else None
|
|
80
76
|
agent_id = None
|
|
81
77
|
team_id = team.id
|
|
82
78
|
else:
|
|
@@ -98,13 +94,17 @@ async def run_agent_as_judge_eval(
|
|
|
98
94
|
if not result:
|
|
99
95
|
raise HTTPException(status_code=500, detail="Failed to run agent as judge evaluation")
|
|
100
96
|
|
|
97
|
+
# Use evaluator's model
|
|
98
|
+
eval_model_id = agent_as_judge_eval.model.id if agent_as_judge_eval.model is not None else None
|
|
99
|
+
eval_model_provider = agent_as_judge_eval.model.provider if agent_as_judge_eval.model is not None else None
|
|
100
|
+
|
|
101
101
|
eval_run = EvalSchema.from_agent_as_judge_eval(
|
|
102
102
|
agent_as_judge_eval=agent_as_judge_eval,
|
|
103
103
|
result=result,
|
|
104
104
|
agent_id=agent_id,
|
|
105
105
|
team_id=team_id,
|
|
106
|
-
model_id=
|
|
107
|
-
model_provider=
|
|
106
|
+
model_id=eval_model_id,
|
|
107
|
+
model_provider=eval_model_provider,
|
|
108
108
|
)
|
|
109
109
|
|
|
110
110
|
# Restore original model after eval
|
|
@@ -82,7 +82,7 @@ class ContentResponseSchema(BaseModel):
|
|
|
82
82
|
status=status,
|
|
83
83
|
status_message=content.get("status_message"),
|
|
84
84
|
created_at=parse_timestamp(content.get("created_at")),
|
|
85
|
-
updated_at=parse_timestamp(content.get("updated_at")),
|
|
85
|
+
updated_at=parse_timestamp(content.get("updated_at", content.get("created_at", 0))),
|
|
86
86
|
# TODO: These fields are not available in the Content class. Fix the inconsistency
|
|
87
87
|
access_count=None,
|
|
88
88
|
linked_to=None,
|
|
@@ -1,3 +1,4 @@
|
|
|
1
|
+
import json
|
|
1
2
|
from datetime import datetime, timezone
|
|
2
3
|
from typing import Any, Dict, List, Optional
|
|
3
4
|
|
|
@@ -25,12 +26,24 @@ class UserMemorySchema(BaseModel):
|
|
|
25
26
|
if memory_dict["memory"] == "":
|
|
26
27
|
return None
|
|
27
28
|
|
|
29
|
+
# Handle nested memory content (relevant for some memories migrated from v1)
|
|
30
|
+
if isinstance(memory_dict["memory"], dict):
|
|
31
|
+
if memory_dict["memory"].get("memory") is not None:
|
|
32
|
+
memory = str(memory_dict["memory"]["memory"])
|
|
33
|
+
else:
|
|
34
|
+
try:
|
|
35
|
+
memory = json.dumps(memory_dict["memory"])
|
|
36
|
+
except json.JSONDecodeError:
|
|
37
|
+
memory = str(memory_dict["memory"])
|
|
38
|
+
else:
|
|
39
|
+
memory = memory_dict["memory"]
|
|
40
|
+
|
|
28
41
|
return cls(
|
|
29
42
|
memory_id=memory_dict["memory_id"],
|
|
30
43
|
user_id=str(memory_dict["user_id"]),
|
|
31
44
|
agent_id=memory_dict.get("agent_id"),
|
|
32
45
|
team_id=memory_dict.get("team_id"),
|
|
33
|
-
memory=
|
|
46
|
+
memory=memory,
|
|
34
47
|
topics=memory_dict.get("topics", []),
|
|
35
48
|
updated_at=memory_dict["updated_at"],
|
|
36
49
|
)
|
|
@@ -35,7 +35,7 @@ class DayAggregatedMetrics(BaseModel):
|
|
|
35
35
|
team_runs_count=metrics_dict.get("team_runs_count", 0),
|
|
36
36
|
team_sessions_count=metrics_dict.get("team_sessions_count", 0),
|
|
37
37
|
token_metrics=metrics_dict.get("token_metrics", {}),
|
|
38
|
-
updated_at=metrics_dict.get("updated_at", 0),
|
|
38
|
+
updated_at=metrics_dict.get("updated_at", metrics_dict.get("created_at", 0)),
|
|
39
39
|
users_count=metrics_dict.get("users_count", 0),
|
|
40
40
|
workflow_runs_count=metrics_dict.get("workflow_runs_count", 0),
|
|
41
41
|
workflow_sessions_count=metrics_dict.get("workflow_sessions_count", 0),
|
agno/os/schema.py
CHANGED
|
@@ -193,7 +193,7 @@ class SessionSchema(BaseModel):
|
|
|
193
193
|
session_data = session.get("session_data", {}) or {}
|
|
194
194
|
|
|
195
195
|
created_at = session.get("created_at", 0)
|
|
196
|
-
updated_at = session.get("updated_at",
|
|
196
|
+
updated_at = session.get("updated_at", created_at)
|
|
197
197
|
|
|
198
198
|
# Handle created_at and updated_at as either ISO 8601 string or timestamp
|
|
199
199
|
def parse_datetime(val):
|
|
@@ -213,7 +213,7 @@ class SessionSchema(BaseModel):
|
|
|
213
213
|
return None
|
|
214
214
|
|
|
215
215
|
created_at = parse_datetime(session.get("created_at", 0))
|
|
216
|
-
updated_at = parse_datetime(session.get("updated_at",
|
|
216
|
+
updated_at = parse_datetime(session.get("updated_at", created_at))
|
|
217
217
|
return cls(
|
|
218
218
|
session_id=session.get("session_id", ""),
|
|
219
219
|
session_name=session_name,
|
|
@@ -265,6 +265,8 @@ class AgentSessionDetailSchema(BaseModel):
|
|
|
265
265
|
@classmethod
|
|
266
266
|
def from_session(cls, session: AgentSession) -> "AgentSessionDetailSchema":
|
|
267
267
|
session_name = get_session_name({**session.to_dict(), "session_type": "agent"})
|
|
268
|
+
created_at = datetime.fromtimestamp(session.created_at, tz=timezone.utc) if session.created_at else None
|
|
269
|
+
updated_at = datetime.fromtimestamp(session.updated_at, tz=timezone.utc) if session.updated_at else created_at
|
|
268
270
|
return cls(
|
|
269
271
|
user_id=session.user_id,
|
|
270
272
|
agent_session_id=session.session_id,
|
|
@@ -280,8 +282,8 @@ class AgentSessionDetailSchema(BaseModel):
|
|
|
280
282
|
metrics=session.session_data.get("session_metrics", {}) if session.session_data else None, # type: ignore
|
|
281
283
|
metadata=session.metadata,
|
|
282
284
|
chat_history=[message.to_dict() for message in session.get_chat_history()],
|
|
283
|
-
created_at=
|
|
284
|
-
updated_at=
|
|
285
|
+
created_at=created_at,
|
|
286
|
+
updated_at=updated_at,
|
|
285
287
|
)
|
|
286
288
|
|
|
287
289
|
|
|
@@ -304,7 +306,8 @@ class TeamSessionDetailSchema(BaseModel):
|
|
|
304
306
|
def from_session(cls, session: TeamSession) -> "TeamSessionDetailSchema":
|
|
305
307
|
session_dict = session.to_dict()
|
|
306
308
|
session_name = get_session_name({**session_dict, "session_type": "team"})
|
|
307
|
-
|
|
309
|
+
created_at = datetime.fromtimestamp(session.created_at, tz=timezone.utc) if session.created_at else None
|
|
310
|
+
updated_at = datetime.fromtimestamp(session.updated_at, tz=timezone.utc) if session.updated_at else created_at
|
|
308
311
|
return cls(
|
|
309
312
|
session_id=session.session_id,
|
|
310
313
|
team_id=session.team_id,
|
|
@@ -319,8 +322,8 @@ class TeamSessionDetailSchema(BaseModel):
|
|
|
319
322
|
metrics=session.session_data.get("session_metrics", {}) if session.session_data else None,
|
|
320
323
|
metadata=session.metadata,
|
|
321
324
|
chat_history=[message.to_dict() for message in session.get_chat_history()],
|
|
322
|
-
created_at=
|
|
323
|
-
updated_at=
|
|
325
|
+
created_at=created_at,
|
|
326
|
+
updated_at=updated_at,
|
|
324
327
|
)
|
|
325
328
|
|
|
326
329
|
|
|
@@ -343,7 +346,6 @@ class WorkflowSessionDetailSchema(BaseModel):
|
|
|
343
346
|
def from_session(cls, session: WorkflowSession) -> "WorkflowSessionDetailSchema":
|
|
344
347
|
session_dict = session.to_dict()
|
|
345
348
|
session_name = get_session_name({**session_dict, "session_type": "workflow"})
|
|
346
|
-
|
|
347
349
|
return cls(
|
|
348
350
|
session_id=session.session_id,
|
|
349
351
|
user_id=session.user_id,
|
|
@@ -355,7 +357,7 @@ class WorkflowSessionDetailSchema(BaseModel):
|
|
|
355
357
|
workflow_data=session.workflow_data,
|
|
356
358
|
metadata=session.metadata,
|
|
357
359
|
created_at=session.created_at,
|
|
358
|
-
updated_at=session.updated_at,
|
|
360
|
+
updated_at=session.updated_at or session.created_at,
|
|
359
361
|
)
|
|
360
362
|
|
|
361
363
|
|
agno/run/__init__.py
CHANGED
|
@@ -1,6 +1,4 @@
|
|
|
1
1
|
from agno.run.base import RunContext, RunStatus
|
|
2
|
+
from agno.run.cancel import get_cancellation_manager, set_cancellation_manager
|
|
2
3
|
|
|
3
|
-
__all__ = [
|
|
4
|
-
"RunContext",
|
|
5
|
-
"RunStatus",
|
|
6
|
-
]
|
|
4
|
+
__all__ = ["RunContext", "RunStatus", "get_cancellation_manager", "set_cancellation_manager"]
|
agno/run/agent.py
CHANGED
|
@@ -55,8 +55,11 @@ class RunInput:
|
|
|
55
55
|
return self.input_content.model_dump_json(exclude_none=True)
|
|
56
56
|
elif isinstance(self.input_content, Message):
|
|
57
57
|
return json.dumps(self.input_content.to_dict())
|
|
58
|
-
elif isinstance(self.input_content, list)
|
|
59
|
-
|
|
58
|
+
elif isinstance(self.input_content, list):
|
|
59
|
+
try:
|
|
60
|
+
return json.dumps(self.to_dict().get("input_content"))
|
|
61
|
+
except Exception:
|
|
62
|
+
return str(self.input_content)
|
|
60
63
|
else:
|
|
61
64
|
return str(self.input_content)
|
|
62
65
|
|
|
@@ -71,22 +74,15 @@ class RunInput:
|
|
|
71
74
|
result["input_content"] = self.input_content.model_dump(exclude_none=True)
|
|
72
75
|
elif isinstance(self.input_content, Message):
|
|
73
76
|
result["input_content"] = self.input_content.to_dict()
|
|
74
|
-
|
|
75
|
-
|
|
76
|
-
|
|
77
|
-
|
|
78
|
-
|
|
79
|
-
|
|
80
|
-
|
|
81
|
-
|
|
82
|
-
|
|
83
|
-
# Handle input_content provided as a list of dicts
|
|
84
|
-
elif (
|
|
85
|
-
isinstance(self.input_content, list) and self.input_content and isinstance(self.input_content[0], dict)
|
|
86
|
-
):
|
|
87
|
-
for content in self.input_content:
|
|
88
|
-
# Handle media input
|
|
89
|
-
if isinstance(content, dict):
|
|
77
|
+
elif isinstance(self.input_content, list):
|
|
78
|
+
serialized_items: List[Any] = []
|
|
79
|
+
for item in self.input_content:
|
|
80
|
+
if isinstance(item, Message):
|
|
81
|
+
serialized_items.append(item.to_dict())
|
|
82
|
+
elif isinstance(item, BaseModel):
|
|
83
|
+
serialized_items.append(item.model_dump(exclude_none=True))
|
|
84
|
+
elif isinstance(item, dict):
|
|
85
|
+
content = dict(item)
|
|
90
86
|
if content.get("images"):
|
|
91
87
|
content["images"] = [
|
|
92
88
|
img.to_dict() if isinstance(img, Image) else img for img in content["images"]
|
|
@@ -103,7 +99,11 @@ class RunInput:
|
|
|
103
99
|
content["files"] = [
|
|
104
100
|
file.to_dict() if isinstance(file, File) else file for file in content["files"]
|
|
105
101
|
]
|
|
106
|
-
|
|
102
|
+
serialized_items.append(content)
|
|
103
|
+
else:
|
|
104
|
+
serialized_items.append(item)
|
|
105
|
+
|
|
106
|
+
result["input_content"] = serialized_items
|
|
107
107
|
else:
|
|
108
108
|
result["input_content"] = self.input_content
|
|
109
109
|
|
agno/run/cancel.py
CHANGED
|
@@ -1,64 +1,37 @@
|
|
|
1
1
|
"""Run cancellation management."""
|
|
2
2
|
|
|
3
|
-
import threading
|
|
4
3
|
from typing import Dict
|
|
5
4
|
|
|
6
|
-
from agno.
|
|
5
|
+
from agno.run.cancellation_management.base import BaseRunCancellationManager
|
|
6
|
+
from agno.run.cancellation_management.in_memory_cancellation_manager import InMemoryRunCancellationManager
|
|
7
7
|
from agno.utils.log import logger
|
|
8
8
|
|
|
9
|
+
# Global cancellation manager instance
|
|
10
|
+
_cancellation_manager: BaseRunCancellationManager = InMemoryRunCancellationManager()
|
|
9
11
|
|
|
10
|
-
class RunCancellationManager:
|
|
11
|
-
"""Manages cancellation state for agent runs."""
|
|
12
|
-
|
|
13
|
-
def __init__(self):
|
|
14
|
-
self._cancelled_runs: Dict[str, bool] = {}
|
|
15
|
-
self._lock = threading.Lock()
|
|
16
|
-
|
|
17
|
-
def register_run(self, run_id: str) -> None:
|
|
18
|
-
"""Register a new run as not cancelled."""
|
|
19
|
-
with self._lock:
|
|
20
|
-
self._cancelled_runs[run_id] = False
|
|
21
|
-
|
|
22
|
-
def cancel_run(self, run_id: str) -> bool:
|
|
23
|
-
"""Cancel a run by marking it as cancelled.
|
|
24
|
-
|
|
25
|
-
Returns:
|
|
26
|
-
bool: True if run was found and cancelled, False if run not found.
|
|
27
|
-
"""
|
|
28
|
-
with self._lock:
|
|
29
|
-
if run_id in self._cancelled_runs:
|
|
30
|
-
self._cancelled_runs[run_id] = True
|
|
31
|
-
logger.info(f"Run {run_id} marked for cancellation")
|
|
32
|
-
return True
|
|
33
|
-
else:
|
|
34
|
-
logger.warning(f"Attempted to cancel unknown run {run_id}")
|
|
35
|
-
return False
|
|
36
|
-
|
|
37
|
-
def is_cancelled(self, run_id: str) -> bool:
|
|
38
|
-
"""Check if a run is cancelled."""
|
|
39
|
-
with self._lock:
|
|
40
|
-
return self._cancelled_runs.get(run_id, False)
|
|
41
|
-
|
|
42
|
-
def cleanup_run(self, run_id: str) -> None:
|
|
43
|
-
"""Remove a run from tracking (called when run completes)."""
|
|
44
|
-
with self._lock:
|
|
45
|
-
if run_id in self._cancelled_runs:
|
|
46
|
-
del self._cancelled_runs[run_id]
|
|
47
|
-
|
|
48
|
-
def raise_if_cancelled(self, run_id: str) -> None:
|
|
49
|
-
"""Check if a run should be cancelled and raise exception if so."""
|
|
50
|
-
if self.is_cancelled(run_id):
|
|
51
|
-
logger.info(f"Cancelling run {run_id}")
|
|
52
|
-
raise RunCancelledException(f"Run {run_id} was cancelled")
|
|
53
|
-
|
|
54
|
-
def get_active_runs(self) -> Dict[str, bool]:
|
|
55
|
-
"""Get all currently tracked runs and their cancellation status."""
|
|
56
|
-
with self._lock:
|
|
57
|
-
return self._cancelled_runs.copy()
|
|
58
12
|
|
|
13
|
+
def set_cancellation_manager(manager: BaseRunCancellationManager) -> None:
|
|
14
|
+
"""Set a custom cancellation manager.
|
|
59
15
|
|
|
60
|
-
|
|
61
|
-
|
|
16
|
+
Args:
|
|
17
|
+
manager: A BaseRunCancellationManager instance or subclass.
|
|
18
|
+
|
|
19
|
+
Example:
|
|
20
|
+
```python
|
|
21
|
+
class MyCustomManager(BaseRunCancellationManager):
|
|
22
|
+
....
|
|
23
|
+
|
|
24
|
+
set_cancellation_manager(MyCustomManager())
|
|
25
|
+
```
|
|
26
|
+
"""
|
|
27
|
+
global _cancellation_manager
|
|
28
|
+
_cancellation_manager = manager
|
|
29
|
+
logger.info(f"Cancellation manager set to {type(manager).__name__}")
|
|
30
|
+
|
|
31
|
+
|
|
32
|
+
def get_cancellation_manager() -> BaseRunCancellationManager:
|
|
33
|
+
"""Get the current cancellation manager instance."""
|
|
34
|
+
return _cancellation_manager
|
|
62
35
|
|
|
63
36
|
|
|
64
37
|
def register_run(run_id: str) -> None:
|
|
@@ -66,16 +39,56 @@ def register_run(run_id: str) -> None:
|
|
|
66
39
|
_cancellation_manager.register_run(run_id)
|
|
67
40
|
|
|
68
41
|
|
|
42
|
+
async def aregister_run(run_id: str) -> None:
|
|
43
|
+
"""Register a new run for cancellation tracking (async version)."""
|
|
44
|
+
await _cancellation_manager.aregister_run(run_id)
|
|
45
|
+
|
|
46
|
+
|
|
69
47
|
def cancel_run(run_id: str) -> bool:
|
|
70
48
|
"""Cancel a run."""
|
|
71
49
|
return _cancellation_manager.cancel_run(run_id)
|
|
72
50
|
|
|
73
51
|
|
|
52
|
+
async def acancel_run(run_id: str) -> bool:
|
|
53
|
+
"""Cancel a run (async version)."""
|
|
54
|
+
return await _cancellation_manager.acancel_run(run_id)
|
|
55
|
+
|
|
56
|
+
|
|
57
|
+
def is_cancelled(run_id: str) -> bool:
|
|
58
|
+
"""Check if a run is cancelled."""
|
|
59
|
+
return _cancellation_manager.is_cancelled(run_id)
|
|
60
|
+
|
|
61
|
+
|
|
62
|
+
async def ais_cancelled(run_id: str) -> bool:
|
|
63
|
+
"""Check if a run is cancelled (async version)."""
|
|
64
|
+
return await _cancellation_manager.ais_cancelled(run_id)
|
|
65
|
+
|
|
66
|
+
|
|
74
67
|
def cleanup_run(run_id: str) -> None:
|
|
75
68
|
"""Clean up cancellation tracking for a completed run."""
|
|
76
69
|
_cancellation_manager.cleanup_run(run_id)
|
|
77
70
|
|
|
78
71
|
|
|
72
|
+
async def acleanup_run(run_id: str) -> None:
|
|
73
|
+
"""Clean up cancellation tracking for a completed run (async version)."""
|
|
74
|
+
await _cancellation_manager.acleanup_run(run_id)
|
|
75
|
+
|
|
76
|
+
|
|
79
77
|
def raise_if_cancelled(run_id: str) -> None:
|
|
80
78
|
"""Check if a run should be cancelled and raise exception if so."""
|
|
81
79
|
_cancellation_manager.raise_if_cancelled(run_id)
|
|
80
|
+
|
|
81
|
+
|
|
82
|
+
async def araise_if_cancelled(run_id: str) -> None:
|
|
83
|
+
"""Check if a run should be cancelled and raise exception if so (async version)."""
|
|
84
|
+
await _cancellation_manager.araise_if_cancelled(run_id)
|
|
85
|
+
|
|
86
|
+
|
|
87
|
+
def get_active_runs() -> Dict[str, bool]:
|
|
88
|
+
"""Get all currently tracked runs and their cancellation status."""
|
|
89
|
+
return _cancellation_manager.get_active_runs()
|
|
90
|
+
|
|
91
|
+
|
|
92
|
+
async def aget_active_runs() -> Dict[str, bool]:
|
|
93
|
+
"""Get all currently tracked runs and their cancellation status (async version)."""
|
|
94
|
+
return await _cancellation_manager.aget_active_runs()
|
|
@@ -0,0 +1,9 @@
|
|
|
1
|
+
from agno.run.cancellation_management.base import BaseRunCancellationManager
|
|
2
|
+
from agno.run.cancellation_management.in_memory_cancellation_manager import InMemoryRunCancellationManager
|
|
3
|
+
from agno.run.cancellation_management.redis_cancellation_manager import RedisRunCancellationManager
|
|
4
|
+
|
|
5
|
+
__all__ = [
|
|
6
|
+
"BaseRunCancellationManager",
|
|
7
|
+
"InMemoryRunCancellationManager",
|
|
8
|
+
"RedisRunCancellationManager",
|
|
9
|
+
]
|
|
@@ -0,0 +1,78 @@
|
|
|
1
|
+
from abc import ABC, abstractmethod
|
|
2
|
+
from typing import Dict
|
|
3
|
+
|
|
4
|
+
|
|
5
|
+
class BaseRunCancellationManager(ABC):
|
|
6
|
+
"""Manages cancellation state for agent runs.
|
|
7
|
+
|
|
8
|
+
This class can be extended to implement custom cancellation logic.
|
|
9
|
+
Use set_cancellation_manager() to replace the global instance with your own.
|
|
10
|
+
"""
|
|
11
|
+
|
|
12
|
+
@abstractmethod
|
|
13
|
+
def register_run(self, run_id: str) -> None:
|
|
14
|
+
"""Register a new run as not cancelled."""
|
|
15
|
+
pass
|
|
16
|
+
|
|
17
|
+
@abstractmethod
|
|
18
|
+
async def aregister_run(self, run_id: str) -> None:
|
|
19
|
+
"""Register a new run as not cancelled (async version)."""
|
|
20
|
+
pass
|
|
21
|
+
|
|
22
|
+
@abstractmethod
|
|
23
|
+
def cancel_run(self, run_id: str) -> bool:
|
|
24
|
+
"""Cancel a run by marking it as cancelled.
|
|
25
|
+
|
|
26
|
+
Returns:
|
|
27
|
+
bool: True if run was found and cancelled, False if run not found.
|
|
28
|
+
"""
|
|
29
|
+
pass
|
|
30
|
+
|
|
31
|
+
@abstractmethod
|
|
32
|
+
async def acancel_run(self, run_id: str) -> bool:
|
|
33
|
+
"""Cancel a run by marking it as cancelled (async version).
|
|
34
|
+
|
|
35
|
+
Returns:
|
|
36
|
+
bool: True if run was found and cancelled, False if run not found.
|
|
37
|
+
"""
|
|
38
|
+
pass
|
|
39
|
+
|
|
40
|
+
@abstractmethod
|
|
41
|
+
def is_cancelled(self, run_id: str) -> bool:
|
|
42
|
+
"""Check if a run is cancelled."""
|
|
43
|
+
pass
|
|
44
|
+
|
|
45
|
+
@abstractmethod
|
|
46
|
+
async def ais_cancelled(self, run_id: str) -> bool:
|
|
47
|
+
"""Check if a run is cancelled (async version)."""
|
|
48
|
+
pass
|
|
49
|
+
|
|
50
|
+
@abstractmethod
|
|
51
|
+
def cleanup_run(self, run_id: str) -> None:
|
|
52
|
+
"""Remove a run from tracking (called when run completes)."""
|
|
53
|
+
pass
|
|
54
|
+
|
|
55
|
+
@abstractmethod
|
|
56
|
+
async def acleanup_run(self, run_id: str) -> None:
|
|
57
|
+
"""Remove a run from tracking (called when run completes) (async version)."""
|
|
58
|
+
pass
|
|
59
|
+
|
|
60
|
+
@abstractmethod
|
|
61
|
+
def raise_if_cancelled(self, run_id: str) -> None:
|
|
62
|
+
"""Check if a run should be cancelled and raise exception if so."""
|
|
63
|
+
pass
|
|
64
|
+
|
|
65
|
+
@abstractmethod
|
|
66
|
+
async def araise_if_cancelled(self, run_id: str) -> None:
|
|
67
|
+
"""Check if a run should be cancelled and raise exception if so (async version)."""
|
|
68
|
+
pass
|
|
69
|
+
|
|
70
|
+
@abstractmethod
|
|
71
|
+
def get_active_runs(self) -> Dict[str, bool]:
|
|
72
|
+
"""Get all currently tracked runs and their cancellation status."""
|
|
73
|
+
pass
|
|
74
|
+
|
|
75
|
+
@abstractmethod
|
|
76
|
+
async def aget_active_runs(self) -> Dict[str, bool]:
|
|
77
|
+
"""Get all currently tracked runs and their cancellation status (async version)."""
|
|
78
|
+
pass
|
|
@@ -0,0 +1,100 @@
|
|
|
1
|
+
"""Run cancellation management."""
|
|
2
|
+
|
|
3
|
+
import asyncio
|
|
4
|
+
import threading
|
|
5
|
+
from typing import Dict
|
|
6
|
+
|
|
7
|
+
from agno.exceptions import RunCancelledException
|
|
8
|
+
from agno.run.cancellation_management.base import BaseRunCancellationManager
|
|
9
|
+
from agno.utils.log import logger
|
|
10
|
+
|
|
11
|
+
|
|
12
|
+
class InMemoryRunCancellationManager(BaseRunCancellationManager):
|
|
13
|
+
def __init__(self):
|
|
14
|
+
self._cancelled_runs: Dict[str, bool] = {}
|
|
15
|
+
self._lock = threading.Lock()
|
|
16
|
+
self._async_lock = asyncio.Lock()
|
|
17
|
+
|
|
18
|
+
def register_run(self, run_id: str) -> None:
|
|
19
|
+
"""Register a new run as not cancelled."""
|
|
20
|
+
with self._lock:
|
|
21
|
+
self._cancelled_runs[run_id] = False
|
|
22
|
+
|
|
23
|
+
async def aregister_run(self, run_id: str) -> None:
|
|
24
|
+
"""Register a new run as not cancelled (async version)."""
|
|
25
|
+
async with self._async_lock:
|
|
26
|
+
self._cancelled_runs[run_id] = False
|
|
27
|
+
|
|
28
|
+
def cancel_run(self, run_id: str) -> bool:
|
|
29
|
+
"""Cancel a run by marking it as cancelled.
|
|
30
|
+
|
|
31
|
+
Returns:
|
|
32
|
+
bool: True if run was found and cancelled, False if run not found.
|
|
33
|
+
"""
|
|
34
|
+
with self._lock:
|
|
35
|
+
if run_id in self._cancelled_runs:
|
|
36
|
+
self._cancelled_runs[run_id] = True
|
|
37
|
+
logger.info(f"Run {run_id} marked for cancellation")
|
|
38
|
+
return True
|
|
39
|
+
else:
|
|
40
|
+
logger.warning(f"Attempted to cancel unknown run {run_id}")
|
|
41
|
+
return False
|
|
42
|
+
|
|
43
|
+
async def acancel_run(self, run_id: str) -> bool:
|
|
44
|
+
"""Cancel a run by marking it as cancelled (async version).
|
|
45
|
+
|
|
46
|
+
Returns:
|
|
47
|
+
bool: True if run was found and cancelled, False if run not found.
|
|
48
|
+
"""
|
|
49
|
+
async with self._async_lock:
|
|
50
|
+
if run_id in self._cancelled_runs:
|
|
51
|
+
self._cancelled_runs[run_id] = True
|
|
52
|
+
logger.info(f"Run {run_id} marked for cancellation")
|
|
53
|
+
return True
|
|
54
|
+
else:
|
|
55
|
+
logger.warning(f"Attempted to cancel unknown run {run_id}")
|
|
56
|
+
return False
|
|
57
|
+
|
|
58
|
+
def is_cancelled(self, run_id: str) -> bool:
|
|
59
|
+
"""Check if a run is cancelled."""
|
|
60
|
+
with self._lock:
|
|
61
|
+
return self._cancelled_runs.get(run_id, False)
|
|
62
|
+
|
|
63
|
+
async def ais_cancelled(self, run_id: str) -> bool:
|
|
64
|
+
"""Check if a run is cancelled (async version)."""
|
|
65
|
+
async with self._async_lock:
|
|
66
|
+
return self._cancelled_runs.get(run_id, False)
|
|
67
|
+
|
|
68
|
+
def cleanup_run(self, run_id: str) -> None:
|
|
69
|
+
"""Remove a run from tracking (called when run completes)."""
|
|
70
|
+
with self._lock:
|
|
71
|
+
if run_id in self._cancelled_runs:
|
|
72
|
+
del self._cancelled_runs[run_id]
|
|
73
|
+
|
|
74
|
+
async def acleanup_run(self, run_id: str) -> None:
|
|
75
|
+
"""Remove a run from tracking (called when run completes) (async version)."""
|
|
76
|
+
async with self._async_lock:
|
|
77
|
+
if run_id in self._cancelled_runs:
|
|
78
|
+
del self._cancelled_runs[run_id]
|
|
79
|
+
|
|
80
|
+
def raise_if_cancelled(self, run_id: str) -> None:
|
|
81
|
+
"""Check if a run should be cancelled and raise exception if so."""
|
|
82
|
+
if self.is_cancelled(run_id):
|
|
83
|
+
logger.info(f"Cancelling run {run_id}")
|
|
84
|
+
raise RunCancelledException(f"Run {run_id} was cancelled")
|
|
85
|
+
|
|
86
|
+
async def araise_if_cancelled(self, run_id: str) -> None:
|
|
87
|
+
"""Check if a run should be cancelled and raise exception if so (async version)."""
|
|
88
|
+
if await self.ais_cancelled(run_id):
|
|
89
|
+
logger.info(f"Cancelling run {run_id}")
|
|
90
|
+
raise RunCancelledException(f"Run {run_id} was cancelled")
|
|
91
|
+
|
|
92
|
+
def get_active_runs(self) -> Dict[str, bool]:
|
|
93
|
+
"""Get all currently tracked runs and their cancellation status."""
|
|
94
|
+
with self._lock:
|
|
95
|
+
return self._cancelled_runs.copy()
|
|
96
|
+
|
|
97
|
+
async def aget_active_runs(self) -> Dict[str, bool]:
|
|
98
|
+
"""Get all currently tracked runs and their cancellation status (async version)."""
|
|
99
|
+
async with self._async_lock:
|
|
100
|
+
return self._cancelled_runs.copy()
|