agno 2.3.19__py3-none-any.whl → 2.3.21__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (39):
  1. agno/agent/agent.py +2466 -2048
  2. agno/db/dynamo/utils.py +26 -3
  3. agno/db/firestore/utils.py +25 -10
  4. agno/db/gcs_json/utils.py +14 -2
  5. agno/db/in_memory/utils.py +14 -2
  6. agno/db/json/utils.py +14 -2
  7. agno/db/mysql/utils.py +13 -3
  8. agno/db/postgres/utils.py +13 -3
  9. agno/db/redis/utils.py +26 -10
  10. agno/db/schemas/memory.py +15 -19
  11. agno/db/singlestore/utils.py +13 -3
  12. agno/db/sqlite/utils.py +15 -3
  13. agno/db/utils.py +22 -0
  14. agno/eval/agent_as_judge.py +24 -14
  15. agno/knowledge/embedder/mistral.py +1 -1
  16. agno/models/litellm/chat.py +6 -0
  17. agno/os/routers/evals/evals.py +0 -9
  18. agno/os/routers/evals/utils.py +6 -6
  19. agno/os/routers/knowledge/schemas.py +1 -1
  20. agno/os/routers/memory/schemas.py +14 -1
  21. agno/os/routers/metrics/schemas.py +1 -1
  22. agno/os/schema.py +11 -9
  23. agno/run/__init__.py +2 -4
  24. agno/run/agent.py +19 -19
  25. agno/run/cancel.py +65 -52
  26. agno/run/cancellation_management/__init__.py +9 -0
  27. agno/run/cancellation_management/base.py +78 -0
  28. agno/run/cancellation_management/in_memory_cancellation_manager.py +100 -0
  29. agno/run/cancellation_management/redis_cancellation_manager.py +236 -0
  30. agno/run/team.py +19 -19
  31. agno/team/team.py +1217 -1136
  32. agno/utils/response.py +1 -13
  33. agno/vectordb/weaviate/__init__.py +1 -1
  34. agno/workflow/workflow.py +23 -16
  35. {agno-2.3.19.dist-info → agno-2.3.21.dist-info}/METADATA +60 -129
  36. {agno-2.3.19.dist-info → agno-2.3.21.dist-info}/RECORD +39 -35
  37. {agno-2.3.19.dist-info → agno-2.3.21.dist-info}/WHEEL +0 -0
  38. {agno-2.3.19.dist-info → agno-2.3.21.dist-info}/licenses/LICENSE +0 -0
  39. {agno-2.3.19.dist-info → agno-2.3.21.dist-info}/top_level.txt +0 -0
@@ -68,15 +68,11 @@ async def run_agent_as_judge_eval(
68
68
  if agent:
69
69
  agent_response = await agent.arun(eval_run_input.input, stream=False)
70
70
  output = str(agent_response.content) if agent_response.content else ""
71
- model_id = agent.model.id if agent and agent.model else None
72
- model_provider = agent.model.provider if agent and agent.model else None
73
71
  agent_id = agent.id
74
72
  team_id = None
75
73
  elif team:
76
74
  team_response = await team.arun(eval_run_input.input, stream=False)
77
75
  output = str(team_response.content) if team_response.content else ""
78
- model_id = team.model.id if team and team.model else None
79
- model_provider = team.model.provider if team and team.model else None
80
76
  agent_id = None
81
77
  team_id = team.id
82
78
  else:
@@ -98,13 +94,17 @@ async def run_agent_as_judge_eval(
98
94
  if not result:
99
95
  raise HTTPException(status_code=500, detail="Failed to run agent as judge evaluation")
100
96
 
97
+ # Use evaluator's model
98
+ eval_model_id = agent_as_judge_eval.model.id if agent_as_judge_eval.model is not None else None
99
+ eval_model_provider = agent_as_judge_eval.model.provider if agent_as_judge_eval.model is not None else None
100
+
101
101
  eval_run = EvalSchema.from_agent_as_judge_eval(
102
102
  agent_as_judge_eval=agent_as_judge_eval,
103
103
  result=result,
104
104
  agent_id=agent_id,
105
105
  team_id=team_id,
106
- model_id=model_id,
107
- model_provider=model_provider,
106
+ model_id=eval_model_id,
107
+ model_provider=eval_model_provider,
108
108
  )
109
109
 
110
110
  # Restore original model after eval
@@ -82,7 +82,7 @@ class ContentResponseSchema(BaseModel):
82
82
  status=status,
83
83
  status_message=content.get("status_message"),
84
84
  created_at=parse_timestamp(content.get("created_at")),
85
- updated_at=parse_timestamp(content.get("updated_at")),
85
+ updated_at=parse_timestamp(content.get("updated_at", content.get("created_at", 0))),
86
86
  # TODO: These fields are not available in the Content class. Fix the inconsistency
87
87
  access_count=None,
88
88
  linked_to=None,
@@ -1,3 +1,4 @@
1
+ import json
1
2
  from datetime import datetime, timezone
2
3
  from typing import Any, Dict, List, Optional
3
4
 
@@ -25,12 +26,24 @@ class UserMemorySchema(BaseModel):
25
26
  if memory_dict["memory"] == "":
26
27
  return None
27
28
 
29
+ # Handle nested memory content (relevant for some memories migrated from v1)
30
+ if isinstance(memory_dict["memory"], dict):
31
+ if memory_dict["memory"].get("memory") is not None:
32
+ memory = str(memory_dict["memory"]["memory"])
33
+ else:
34
+ try:
35
+ memory = json.dumps(memory_dict["memory"])
36
+ except json.JSONDecodeError:
37
+ memory = str(memory_dict["memory"])
38
+ else:
39
+ memory = memory_dict["memory"]
40
+
28
41
  return cls(
29
42
  memory_id=memory_dict["memory_id"],
30
43
  user_id=str(memory_dict["user_id"]),
31
44
  agent_id=memory_dict.get("agent_id"),
32
45
  team_id=memory_dict.get("team_id"),
33
- memory=memory_dict["memory"],
46
+ memory=memory,
34
47
  topics=memory_dict.get("topics", []),
35
48
  updated_at=memory_dict["updated_at"],
36
49
  )
@@ -35,7 +35,7 @@ class DayAggregatedMetrics(BaseModel):
35
35
  team_runs_count=metrics_dict.get("team_runs_count", 0),
36
36
  team_sessions_count=metrics_dict.get("team_sessions_count", 0),
37
37
  token_metrics=metrics_dict.get("token_metrics", {}),
38
- updated_at=metrics_dict.get("updated_at", 0),
38
+ updated_at=metrics_dict.get("updated_at", metrics_dict.get("created_at", 0)),
39
39
  users_count=metrics_dict.get("users_count", 0),
40
40
  workflow_runs_count=metrics_dict.get("workflow_runs_count", 0),
41
41
  workflow_sessions_count=metrics_dict.get("workflow_sessions_count", 0),
agno/os/schema.py CHANGED
@@ -193,7 +193,7 @@ class SessionSchema(BaseModel):
193
193
  session_data = session.get("session_data", {}) or {}
194
194
 
195
195
  created_at = session.get("created_at", 0)
196
- updated_at = session.get("updated_at", 0)
196
+ updated_at = session.get("updated_at", created_at)
197
197
 
198
198
  # Handle created_at and updated_at as either ISO 8601 string or timestamp
199
199
  def parse_datetime(val):
@@ -213,7 +213,7 @@ class SessionSchema(BaseModel):
213
213
  return None
214
214
 
215
215
  created_at = parse_datetime(session.get("created_at", 0))
216
- updated_at = parse_datetime(session.get("updated_at", 0))
216
+ updated_at = parse_datetime(session.get("updated_at", created_at))
217
217
  return cls(
218
218
  session_id=session.get("session_id", ""),
219
219
  session_name=session_name,
@@ -265,6 +265,8 @@ class AgentSessionDetailSchema(BaseModel):
265
265
  @classmethod
266
266
  def from_session(cls, session: AgentSession) -> "AgentSessionDetailSchema":
267
267
  session_name = get_session_name({**session.to_dict(), "session_type": "agent"})
268
+ created_at = datetime.fromtimestamp(session.created_at, tz=timezone.utc) if session.created_at else None
269
+ updated_at = datetime.fromtimestamp(session.updated_at, tz=timezone.utc) if session.updated_at else created_at
268
270
  return cls(
269
271
  user_id=session.user_id,
270
272
  agent_session_id=session.session_id,
@@ -280,8 +282,8 @@ class AgentSessionDetailSchema(BaseModel):
280
282
  metrics=session.session_data.get("session_metrics", {}) if session.session_data else None, # type: ignore
281
283
  metadata=session.metadata,
282
284
  chat_history=[message.to_dict() for message in session.get_chat_history()],
283
- created_at=datetime.fromtimestamp(session.created_at, tz=timezone.utc) if session.created_at else None,
284
- updated_at=datetime.fromtimestamp(session.updated_at, tz=timezone.utc) if session.updated_at else None,
285
+ created_at=created_at,
286
+ updated_at=updated_at,
285
287
  )
286
288
 
287
289
 
@@ -304,7 +306,8 @@ class TeamSessionDetailSchema(BaseModel):
304
306
  def from_session(cls, session: TeamSession) -> "TeamSessionDetailSchema":
305
307
  session_dict = session.to_dict()
306
308
  session_name = get_session_name({**session_dict, "session_type": "team"})
307
-
309
+ created_at = datetime.fromtimestamp(session.created_at, tz=timezone.utc) if session.created_at else None
310
+ updated_at = datetime.fromtimestamp(session.updated_at, tz=timezone.utc) if session.updated_at else created_at
308
311
  return cls(
309
312
  session_id=session.session_id,
310
313
  team_id=session.team_id,
@@ -319,8 +322,8 @@ class TeamSessionDetailSchema(BaseModel):
319
322
  metrics=session.session_data.get("session_metrics", {}) if session.session_data else None,
320
323
  metadata=session.metadata,
321
324
  chat_history=[message.to_dict() for message in session.get_chat_history()],
322
- created_at=datetime.fromtimestamp(session.created_at, tz=timezone.utc) if session.created_at else None,
323
- updated_at=datetime.fromtimestamp(session.updated_at, tz=timezone.utc) if session.updated_at else None,
325
+ created_at=created_at,
326
+ updated_at=updated_at,
324
327
  )
325
328
 
326
329
 
@@ -343,7 +346,6 @@ class WorkflowSessionDetailSchema(BaseModel):
343
346
  def from_session(cls, session: WorkflowSession) -> "WorkflowSessionDetailSchema":
344
347
  session_dict = session.to_dict()
345
348
  session_name = get_session_name({**session_dict, "session_type": "workflow"})
346
-
347
349
  return cls(
348
350
  session_id=session.session_id,
349
351
  user_id=session.user_id,
@@ -355,7 +357,7 @@ class WorkflowSessionDetailSchema(BaseModel):
355
357
  workflow_data=session.workflow_data,
356
358
  metadata=session.metadata,
357
359
  created_at=session.created_at,
358
- updated_at=session.updated_at,
360
+ updated_at=session.updated_at or session.created_at,
359
361
  )
360
362
 
361
363
 
agno/run/__init__.py CHANGED
@@ -1,6 +1,4 @@
1
1
  from agno.run.base import RunContext, RunStatus
2
+ from agno.run.cancel import get_cancellation_manager, set_cancellation_manager
2
3
 
3
- __all__ = [
4
- "RunContext",
5
- "RunStatus",
6
- ]
4
+ __all__ = ["RunContext", "RunStatus", "get_cancellation_manager", "set_cancellation_manager"]
agno/run/agent.py CHANGED
@@ -55,8 +55,11 @@ class RunInput:
55
55
  return self.input_content.model_dump_json(exclude_none=True)
56
56
  elif isinstance(self.input_content, Message):
57
57
  return json.dumps(self.input_content.to_dict())
58
- elif isinstance(self.input_content, list) and self.input_content and isinstance(self.input_content[0], Message):
59
- return json.dumps([m.to_dict() for m in self.input_content])
58
+ elif isinstance(self.input_content, list):
59
+ try:
60
+ return json.dumps(self.to_dict().get("input_content"))
61
+ except Exception:
62
+ return str(self.input_content)
60
63
  else:
61
64
  return str(self.input_content)
62
65
 
@@ -71,22 +74,15 @@ class RunInput:
71
74
  result["input_content"] = self.input_content.model_dump(exclude_none=True)
72
75
  elif isinstance(self.input_content, Message):
73
76
  result["input_content"] = self.input_content.to_dict()
74
-
75
- # Handle input_content provided as a list of Message objects
76
- elif (
77
- isinstance(self.input_content, list)
78
- and self.input_content
79
- and isinstance(self.input_content[0], Message)
80
- ):
81
- result["input_content"] = [m.to_dict() for m in self.input_content]
82
-
83
- # Handle input_content provided as a list of dicts
84
- elif (
85
- isinstance(self.input_content, list) and self.input_content and isinstance(self.input_content[0], dict)
86
- ):
87
- for content in self.input_content:
88
- # Handle media input
89
- if isinstance(content, dict):
77
+ elif isinstance(self.input_content, list):
78
+ serialized_items: List[Any] = []
79
+ for item in self.input_content:
80
+ if isinstance(item, Message):
81
+ serialized_items.append(item.to_dict())
82
+ elif isinstance(item, BaseModel):
83
+ serialized_items.append(item.model_dump(exclude_none=True))
84
+ elif isinstance(item, dict):
85
+ content = dict(item)
90
86
  if content.get("images"):
91
87
  content["images"] = [
92
88
  img.to_dict() if isinstance(img, Image) else img for img in content["images"]
@@ -103,7 +99,11 @@ class RunInput:
103
99
  content["files"] = [
104
100
  file.to_dict() if isinstance(file, File) else file for file in content["files"]
105
101
  ]
106
- result["input_content"] = self.input_content
102
+ serialized_items.append(content)
103
+ else:
104
+ serialized_items.append(item)
105
+
106
+ result["input_content"] = serialized_items
107
107
  else:
108
108
  result["input_content"] = self.input_content
109
109
 
agno/run/cancel.py CHANGED
@@ -1,64 +1,37 @@
1
1
  """Run cancellation management."""
2
2
 
3
- import threading
4
3
  from typing import Dict
5
4
 
6
- from agno.exceptions import RunCancelledException
5
+ from agno.run.cancellation_management.base import BaseRunCancellationManager
6
+ from agno.run.cancellation_management.in_memory_cancellation_manager import InMemoryRunCancellationManager
7
7
  from agno.utils.log import logger
8
8
 
9
+ # Global cancellation manager instance
10
+ _cancellation_manager: BaseRunCancellationManager = InMemoryRunCancellationManager()
9
11
 
10
- class RunCancellationManager:
11
- """Manages cancellation state for agent runs."""
12
-
13
- def __init__(self):
14
- self._cancelled_runs: Dict[str, bool] = {}
15
- self._lock = threading.Lock()
16
-
17
- def register_run(self, run_id: str) -> None:
18
- """Register a new run as not cancelled."""
19
- with self._lock:
20
- self._cancelled_runs[run_id] = False
21
-
22
- def cancel_run(self, run_id: str) -> bool:
23
- """Cancel a run by marking it as cancelled.
24
-
25
- Returns:
26
- bool: True if run was found and cancelled, False if run not found.
27
- """
28
- with self._lock:
29
- if run_id in self._cancelled_runs:
30
- self._cancelled_runs[run_id] = True
31
- logger.info(f"Run {run_id} marked for cancellation")
32
- return True
33
- else:
34
- logger.warning(f"Attempted to cancel unknown run {run_id}")
35
- return False
36
-
37
- def is_cancelled(self, run_id: str) -> bool:
38
- """Check if a run is cancelled."""
39
- with self._lock:
40
- return self._cancelled_runs.get(run_id, False)
41
-
42
- def cleanup_run(self, run_id: str) -> None:
43
- """Remove a run from tracking (called when run completes)."""
44
- with self._lock:
45
- if run_id in self._cancelled_runs:
46
- del self._cancelled_runs[run_id]
47
-
48
- def raise_if_cancelled(self, run_id: str) -> None:
49
- """Check if a run should be cancelled and raise exception if so."""
50
- if self.is_cancelled(run_id):
51
- logger.info(f"Cancelling run {run_id}")
52
- raise RunCancelledException(f"Run {run_id} was cancelled")
53
-
54
- def get_active_runs(self) -> Dict[str, bool]:
55
- """Get all currently tracked runs and their cancellation status."""
56
- with self._lock:
57
- return self._cancelled_runs.copy()
58
12
 
13
+ def set_cancellation_manager(manager: BaseRunCancellationManager) -> None:
14
+ """Set a custom cancellation manager.
59
15
 
60
- # Global cancellation manager instance
61
- _cancellation_manager = RunCancellationManager()
16
+ Args:
17
+ manager: A BaseRunCancellationManager instance or subclass.
18
+
19
+ Example:
20
+ ```python
21
+ class MyCustomManager(BaseRunCancellationManager):
22
+ ....
23
+
24
+ set_cancellation_manager(MyCustomManager())
25
+ ```
26
+ """
27
+ global _cancellation_manager
28
+ _cancellation_manager = manager
29
+ logger.info(f"Cancellation manager set to {type(manager).__name__}")
30
+
31
+
32
+ def get_cancellation_manager() -> BaseRunCancellationManager:
33
+ """Get the current cancellation manager instance."""
34
+ return _cancellation_manager
62
35
 
63
36
 
64
37
  def register_run(run_id: str) -> None:
@@ -66,16 +39,56 @@ def register_run(run_id: str) -> None:
66
39
  _cancellation_manager.register_run(run_id)
67
40
 
68
41
 
42
+ async def aregister_run(run_id: str) -> None:
43
+ """Register a new run for cancellation tracking (async version)."""
44
+ await _cancellation_manager.aregister_run(run_id)
45
+
46
+
69
47
  def cancel_run(run_id: str) -> bool:
70
48
  """Cancel a run."""
71
49
  return _cancellation_manager.cancel_run(run_id)
72
50
 
73
51
 
52
+ async def acancel_run(run_id: str) -> bool:
53
+ """Cancel a run (async version)."""
54
+ return await _cancellation_manager.acancel_run(run_id)
55
+
56
+
57
+ def is_cancelled(run_id: str) -> bool:
58
+ """Check if a run is cancelled."""
59
+ return _cancellation_manager.is_cancelled(run_id)
60
+
61
+
62
+ async def ais_cancelled(run_id: str) -> bool:
63
+ """Check if a run is cancelled (async version)."""
64
+ return await _cancellation_manager.ais_cancelled(run_id)
65
+
66
+
74
67
  def cleanup_run(run_id: str) -> None:
75
68
  """Clean up cancellation tracking for a completed run."""
76
69
  _cancellation_manager.cleanup_run(run_id)
77
70
 
78
71
 
72
+ async def acleanup_run(run_id: str) -> None:
73
+ """Clean up cancellation tracking for a completed run (async version)."""
74
+ await _cancellation_manager.acleanup_run(run_id)
75
+
76
+
79
77
  def raise_if_cancelled(run_id: str) -> None:
80
78
  """Check if a run should be cancelled and raise exception if so."""
81
79
  _cancellation_manager.raise_if_cancelled(run_id)
80
+
81
+
82
+ async def araise_if_cancelled(run_id: str) -> None:
83
+ """Check if a run should be cancelled and raise exception if so (async version)."""
84
+ await _cancellation_manager.araise_if_cancelled(run_id)
85
+
86
+
87
+ def get_active_runs() -> Dict[str, bool]:
88
+ """Get all currently tracked runs and their cancellation status."""
89
+ return _cancellation_manager.get_active_runs()
90
+
91
+
92
+ async def aget_active_runs() -> Dict[str, bool]:
93
+ """Get all currently tracked runs and their cancellation status (async version)."""
94
+ return await _cancellation_manager.aget_active_runs()
@@ -0,0 +1,9 @@
1
+ from agno.run.cancellation_management.base import BaseRunCancellationManager
2
+ from agno.run.cancellation_management.in_memory_cancellation_manager import InMemoryRunCancellationManager
3
+ from agno.run.cancellation_management.redis_cancellation_manager import RedisRunCancellationManager
4
+
5
+ __all__ = [
6
+ "BaseRunCancellationManager",
7
+ "InMemoryRunCancellationManager",
8
+ "RedisRunCancellationManager",
9
+ ]
@@ -0,0 +1,78 @@
1
+ from abc import ABC, abstractmethod
2
+ from typing import Dict
3
+
4
+
5
+ class BaseRunCancellationManager(ABC):
6
+ """Manages cancellation state for agent runs.
7
+
8
+ This class can be extended to implement custom cancellation logic.
9
+ Use set_cancellation_manager() to replace the global instance with your own.
10
+ """
11
+
12
+ @abstractmethod
13
+ def register_run(self, run_id: str) -> None:
14
+ """Register a new run as not cancelled."""
15
+ pass
16
+
17
+ @abstractmethod
18
+ async def aregister_run(self, run_id: str) -> None:
19
+ """Register a new run as not cancelled (async version)."""
20
+ pass
21
+
22
+ @abstractmethod
23
+ def cancel_run(self, run_id: str) -> bool:
24
+ """Cancel a run by marking it as cancelled.
25
+
26
+ Returns:
27
+ bool: True if run was found and cancelled, False if run not found.
28
+ """
29
+ pass
30
+
31
+ @abstractmethod
32
+ async def acancel_run(self, run_id: str) -> bool:
33
+ """Cancel a run by marking it as cancelled (async version).
34
+
35
+ Returns:
36
+ bool: True if run was found and cancelled, False if run not found.
37
+ """
38
+ pass
39
+
40
+ @abstractmethod
41
+ def is_cancelled(self, run_id: str) -> bool:
42
+ """Check if a run is cancelled."""
43
+ pass
44
+
45
+ @abstractmethod
46
+ async def ais_cancelled(self, run_id: str) -> bool:
47
+ """Check if a run is cancelled (async version)."""
48
+ pass
49
+
50
+ @abstractmethod
51
+ def cleanup_run(self, run_id: str) -> None:
52
+ """Remove a run from tracking (called when run completes)."""
53
+ pass
54
+
55
+ @abstractmethod
56
+ async def acleanup_run(self, run_id: str) -> None:
57
+ """Remove a run from tracking (called when run completes) (async version)."""
58
+ pass
59
+
60
+ @abstractmethod
61
+ def raise_if_cancelled(self, run_id: str) -> None:
62
+ """Check if a run should be cancelled and raise exception if so."""
63
+ pass
64
+
65
+ @abstractmethod
66
+ async def araise_if_cancelled(self, run_id: str) -> None:
67
+ """Check if a run should be cancelled and raise exception if so (async version)."""
68
+ pass
69
+
70
+ @abstractmethod
71
+ def get_active_runs(self) -> Dict[str, bool]:
72
+ """Get all currently tracked runs and their cancellation status."""
73
+ pass
74
+
75
+ @abstractmethod
76
+ async def aget_active_runs(self) -> Dict[str, bool]:
77
+ """Get all currently tracked runs and their cancellation status (async version)."""
78
+ pass
@@ -0,0 +1,100 @@
1
+ """Run cancellation management."""
2
+
3
+ import asyncio
4
+ import threading
5
+ from typing import Dict
6
+
7
+ from agno.exceptions import RunCancelledException
8
+ from agno.run.cancellation_management.base import BaseRunCancellationManager
9
+ from agno.utils.log import logger
10
+
11
+
12
+ class InMemoryRunCancellationManager(BaseRunCancellationManager):
13
+ def __init__(self):
14
+ self._cancelled_runs: Dict[str, bool] = {}
15
+ self._lock = threading.Lock()
16
+ self._async_lock = asyncio.Lock()
17
+
18
+ def register_run(self, run_id: str) -> None:
19
+ """Register a new run as not cancelled."""
20
+ with self._lock:
21
+ self._cancelled_runs[run_id] = False
22
+
23
+ async def aregister_run(self, run_id: str) -> None:
24
+ """Register a new run as not cancelled (async version)."""
25
+ async with self._async_lock:
26
+ self._cancelled_runs[run_id] = False
27
+
28
+ def cancel_run(self, run_id: str) -> bool:
29
+ """Cancel a run by marking it as cancelled.
30
+
31
+ Returns:
32
+ bool: True if run was found and cancelled, False if run not found.
33
+ """
34
+ with self._lock:
35
+ if run_id in self._cancelled_runs:
36
+ self._cancelled_runs[run_id] = True
37
+ logger.info(f"Run {run_id} marked for cancellation")
38
+ return True
39
+ else:
40
+ logger.warning(f"Attempted to cancel unknown run {run_id}")
41
+ return False
42
+
43
+ async def acancel_run(self, run_id: str) -> bool:
44
+ """Cancel a run by marking it as cancelled (async version).
45
+
46
+ Returns:
47
+ bool: True if run was found and cancelled, False if run not found.
48
+ """
49
+ async with self._async_lock:
50
+ if run_id in self._cancelled_runs:
51
+ self._cancelled_runs[run_id] = True
52
+ logger.info(f"Run {run_id} marked for cancellation")
53
+ return True
54
+ else:
55
+ logger.warning(f"Attempted to cancel unknown run {run_id}")
56
+ return False
57
+
58
+ def is_cancelled(self, run_id: str) -> bool:
59
+ """Check if a run is cancelled."""
60
+ with self._lock:
61
+ return self._cancelled_runs.get(run_id, False)
62
+
63
+ async def ais_cancelled(self, run_id: str) -> bool:
64
+ """Check if a run is cancelled (async version)."""
65
+ async with self._async_lock:
66
+ return self._cancelled_runs.get(run_id, False)
67
+
68
+ def cleanup_run(self, run_id: str) -> None:
69
+ """Remove a run from tracking (called when run completes)."""
70
+ with self._lock:
71
+ if run_id in self._cancelled_runs:
72
+ del self._cancelled_runs[run_id]
73
+
74
+ async def acleanup_run(self, run_id: str) -> None:
75
+ """Remove a run from tracking (called when run completes) (async version)."""
76
+ async with self._async_lock:
77
+ if run_id in self._cancelled_runs:
78
+ del self._cancelled_runs[run_id]
79
+
80
+ def raise_if_cancelled(self, run_id: str) -> None:
81
+ """Check if a run should be cancelled and raise exception if so."""
82
+ if self.is_cancelled(run_id):
83
+ logger.info(f"Cancelling run {run_id}")
84
+ raise RunCancelledException(f"Run {run_id} was cancelled")
85
+
86
+ async def araise_if_cancelled(self, run_id: str) -> None:
87
+ """Check if a run should be cancelled and raise exception if so (async version)."""
88
+ if await self.ais_cancelled(run_id):
89
+ logger.info(f"Cancelling run {run_id}")
90
+ raise RunCancelledException(f"Run {run_id} was cancelled")
91
+
92
+ def get_active_runs(self) -> Dict[str, bool]:
93
+ """Get all currently tracked runs and their cancellation status."""
94
+ with self._lock:
95
+ return self._cancelled_runs.copy()
96
+
97
+ async def aget_active_runs(self) -> Dict[str, bool]:
98
+ """Get all currently tracked runs and their cancellation status (async version)."""
99
+ async with self._async_lock:
100
+ return self._cancelled_runs.copy()