agno 2.3.19__py3-none-any.whl → 2.3.20__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
agno/run/__init__.py CHANGED
@@ -1,6 +1,4 @@
1
1
  from agno.run.base import RunContext, RunStatus
2
+ from agno.run.cancel import get_cancellation_manager, set_cancellation_manager
2
3
 
3
- __all__ = [
4
- "RunContext",
5
- "RunStatus",
6
- ]
4
+ __all__ = ["RunContext", "RunStatus", "get_cancellation_manager", "set_cancellation_manager"]
agno/run/cancel.py CHANGED
@@ -1,64 +1,37 @@
1
1
  """Run cancellation management."""
2
2
 
3
- import threading
4
3
  from typing import Dict
5
4
 
6
- from agno.exceptions import RunCancelledException
5
+ from agno.run.cancellation_management.base import BaseRunCancellationManager
6
+ from agno.run.cancellation_management.in_memory_cancellation_manager import InMemoryRunCancellationManager
7
7
  from agno.utils.log import logger
8
8
 
9
# Global cancellation manager instance. The module-level helpers in this module
# delegate to it; replace it with set_cancellation_manager().
_cancellation_manager: BaseRunCancellationManager = InMemoryRunCancellationManager()

# Backward-compatible alias: agno <= 2.3.19 defined a concrete
# ``RunCancellationManager`` class in this module. Its in-memory behaviour (same
# register/cancel/is_cancelled/cleanup/raise_if_cancelled/get_active_runs API)
# now lives in InMemoryRunCancellationManager, so keep the old name importable
# instead of breaking ``from agno.run.cancel import RunCancellationManager``.
RunCancellationManager = InMemoryRunCancellationManager


def set_cancellation_manager(manager: BaseRunCancellationManager) -> None:
    """Set a custom cancellation manager.

    Replaces the module-global manager that the module-level helpers
    (``register_run``, ``cancel_run``, ...) delegate to. The change affects
    every later call, not runs already in flight under another manager.

    Args:
        manager: A BaseRunCancellationManager instance or subclass.

    Example:
        ```python
        class MyCustomManager(BaseRunCancellationManager):
            ...

        set_cancellation_manager(MyCustomManager())
        ```
    """
    global _cancellation_manager
    _cancellation_manager = manager
    logger.info(f"Cancellation manager set to {type(manager).__name__}")


def get_cancellation_manager() -> BaseRunCancellationManager:
    """Get the current cancellation manager instance."""
    return _cancellation_manager
def register_run(run_id: str) -> None:
    """Start tracking ``run_id`` as an active, not-cancelled run."""
    get_cancellation_manager().register_run(run_id)


async def aregister_run(run_id: str) -> None:
    """Start tracking ``run_id`` as an active run (coroutine form of ``register_run``)."""
    await get_cancellation_manager().aregister_run(run_id)


def cancel_run(run_id: str) -> bool:
    """Request cancellation of ``run_id``.

    Returns:
        Whatever the active manager reports: True when the run was known and
        flagged, False otherwise.
    """
    return get_cancellation_manager().cancel_run(run_id)


async def acancel_run(run_id: str) -> bool:
    """Request cancellation of ``run_id`` (coroutine form of ``cancel_run``)."""
    return await get_cancellation_manager().acancel_run(run_id)


def is_cancelled(run_id: str) -> bool:
    """Report whether cancellation has been requested for ``run_id``."""
    return get_cancellation_manager().is_cancelled(run_id)


async def ais_cancelled(run_id: str) -> bool:
    """Report whether ``run_id`` is cancelled (coroutine form of ``is_cancelled``)."""
    return await get_cancellation_manager().ais_cancelled(run_id)


def cleanup_run(run_id: str) -> None:
    """Stop tracking ``run_id``; intended to be called once the run finishes."""
    get_cancellation_manager().cleanup_run(run_id)


async def acleanup_run(run_id: str) -> None:
    """Stop tracking ``run_id`` (coroutine form of ``cleanup_run``)."""
    await get_cancellation_manager().acleanup_run(run_id)


def raise_if_cancelled(run_id: str) -> None:
    """Raise if cancellation has been requested for ``run_id``.

    The exception type is chosen by the active manager; the bundled managers
    raise RunCancelledException.
    """
    get_cancellation_manager().raise_if_cancelled(run_id)


async def araise_if_cancelled(run_id: str) -> None:
    """Raise if ``run_id`` is cancelled (coroutine form of ``raise_if_cancelled``)."""
    await get_cancellation_manager().araise_if_cancelled(run_id)


def get_active_runs() -> Dict[str, bool]:
    """Return a mapping of every tracked run_id to its cancellation flag."""
    return get_cancellation_manager().get_active_runs()


async def aget_active_runs() -> Dict[str, bool]:
    """Return tracked runs and their flags (coroutine form of ``get_active_runs``)."""
    return await get_cancellation_manager().aget_active_runs()
@@ -0,0 +1,9 @@
1
+ from agno.run.cancellation_management.base import BaseRunCancellationManager
2
+ from agno.run.cancellation_management.in_memory_cancellation_manager import InMemoryRunCancellationManager
3
+ from agno.run.cancellation_management.redis_cancellation_manager import RedisRunCancellationManager
4
+
5
+ __all__ = [
6
+ "BaseRunCancellationManager",
7
+ "InMemoryRunCancellationManager",
8
+ "RedisRunCancellationManager",
9
+ ]
@@ -0,0 +1,78 @@
1
+ from abc import ABC, abstractmethod
2
+ from typing import Dict
3
+
4
+
5
class BaseRunCancellationManager(ABC):
    """Abstract interface for tracking and cancelling agent runs.

    Every operation exists twice: a synchronous method and an ``a``-prefixed
    coroutine. Concrete managers must implement both forms. Subclass this to
    build custom cancellation logic, then install it globally with
    set_cancellation_manager().
    """

    # ----------------------------------------------------------- registration

    @abstractmethod
    def register_run(self, run_id: str) -> None:
        """Begin tracking ``run_id`` in the not-cancelled state."""
        ...

    @abstractmethod
    async def aregister_run(self, run_id: str) -> None:
        """Coroutine form of :meth:`register_run`."""
        ...

    # ----------------------------------------------------------- cancellation

    @abstractmethod
    def cancel_run(self, run_id: str) -> bool:
        """Flag ``run_id`` as cancelled.

        Returns:
            bool: True if the run was tracked and is now flagged, False if the
            run was not found.
        """
        ...

    @abstractmethod
    async def acancel_run(self, run_id: str) -> bool:
        """Coroutine form of :meth:`cancel_run`.

        Returns:
            bool: True if the run was tracked and is now flagged, False if the
            run was not found.
        """
        ...

    # ---------------------------------------------------------------- queries

    @abstractmethod
    def is_cancelled(self, run_id: str) -> bool:
        """Return whether cancellation was requested for ``run_id``."""
        ...

    @abstractmethod
    async def ais_cancelled(self, run_id: str) -> bool:
        """Coroutine form of :meth:`is_cancelled`."""
        ...

    # ---------------------------------------------------------------- cleanup

    @abstractmethod
    def cleanup_run(self, run_id: str) -> None:
        """Forget ``run_id``; called when the run completes."""
        ...

    @abstractmethod
    async def acleanup_run(self, run_id: str) -> None:
        """Coroutine form of :meth:`cleanup_run`."""
        ...

    # ------------------------------------------------------------- enforcement

    @abstractmethod
    def raise_if_cancelled(self, run_id: str) -> None:
        """Raise an exception when ``run_id`` has been cancelled."""
        ...

    @abstractmethod
    async def araise_if_cancelled(self, run_id: str) -> None:
        """Coroutine form of :meth:`raise_if_cancelled`."""
        ...

    # ------------------------------------------------------------- inspection

    @abstractmethod
    def get_active_runs(self) -> Dict[str, bool]:
        """Return every tracked run_id mapped to its cancellation flag."""
        ...

    @abstractmethod
    async def aget_active_runs(self) -> Dict[str, bool]:
        """Coroutine form of :meth:`get_active_runs`."""
        ...
@@ -0,0 +1,100 @@
1
+ """Run cancellation management."""
2
+
3
+ import asyncio
4
+ import threading
5
+ from typing import Dict
6
+
7
+ from agno.exceptions import RunCancelledException
8
+ from agno.run.cancellation_management.base import BaseRunCancellationManager
9
+ from agno.utils.log import logger
10
+
11
+
12
class InMemoryRunCancellationManager(BaseRunCancellationManager):
    """Cancellation manager that keeps run state in a dict in this process's memory.

    Cancellation requests are only visible inside the current process.

    Every method, sync or async, guards the dict with the same
    ``threading.Lock``. Sync methods may run in worker threads while async
    methods run on an event loop. Two independent locks (a ``threading.Lock``
    for one side, an ``asyncio.Lock`` for the other) would not exclude each
    other, so e.g. ``acancel_run`` could re-insert a run that ``cleanup_run``
    had just removed. The critical sections are short dict operations with no
    ``await`` inside, so taking a thread lock from a coroutine only blocks the
    event loop for an instant. A thread lock is also not tied to any particular
    event loop.
    """

    def __init__(self):
        # run_id -> True once cancellation was requested, False while active.
        self._cancelled_runs: Dict[str, bool] = {}
        # Shared by sync and async methods; never held across an ``await``.
        self._lock = threading.Lock()

    # -- private helpers shared by the sync and async entry points ------------

    def _register(self, run_id: str) -> None:
        """Insert (or reset) ``run_id`` as not cancelled."""
        with self._lock:
            self._cancelled_runs[run_id] = False

    def _cancel(self, run_id: str) -> bool:
        """Flag ``run_id`` as cancelled if tracked, log the outcome, return whether it was."""
        with self._lock:
            found = run_id in self._cancelled_runs
            if found:
                self._cancelled_runs[run_id] = True
        # Logging happens outside the lock to keep the critical section minimal.
        if found:
            logger.info(f"Run {run_id} marked for cancellation")
        else:
            logger.warning(f"Attempted to cancel unknown run {run_id}")
        return found

    def _lookup(self, run_id: str) -> bool:
        """Return the cancellation flag for ``run_id``; unknown runs count as not cancelled."""
        with self._lock:
            return self._cancelled_runs.get(run_id, False)

    def _forget(self, run_id: str) -> None:
        """Drop ``run_id`` from tracking; a no-op if it is not tracked."""
        with self._lock:
            self._cancelled_runs.pop(run_id, None)

    def _snapshot(self) -> Dict[str, bool]:
        """Return a copy of the tracking dict so callers cannot mutate internal state."""
        with self._lock:
            return dict(self._cancelled_runs)

    # -- public API ------------------------------------------------------------

    def register_run(self, run_id: str) -> None:
        """Register a new run as not cancelled."""
        self._register(run_id)

    async def aregister_run(self, run_id: str) -> None:
        """Register a new run as not cancelled (async version)."""
        self._register(run_id)

    def cancel_run(self, run_id: str) -> bool:
        """Cancel a run by marking it as cancelled.

        Returns:
            bool: True if run was found and cancelled, False if run not found.
        """
        return self._cancel(run_id)

    async def acancel_run(self, run_id: str) -> bool:
        """Cancel a run by marking it as cancelled (async version).

        Returns:
            bool: True if run was found and cancelled, False if run not found.
        """
        return self._cancel(run_id)

    def is_cancelled(self, run_id: str) -> bool:
        """Check if a run is cancelled."""
        return self._lookup(run_id)

    async def ais_cancelled(self, run_id: str) -> bool:
        """Check if a run is cancelled (async version)."""
        return self._lookup(run_id)

    def cleanup_run(self, run_id: str) -> None:
        """Remove a run from tracking (called when run completes)."""
        self._forget(run_id)

    async def acleanup_run(self, run_id: str) -> None:
        """Remove a run from tracking (called when run completes) (async version)."""
        self._forget(run_id)

    def raise_if_cancelled(self, run_id: str) -> None:
        """Check if a run should be cancelled and raise exception if so.

        Raises:
            RunCancelledException: If cancellation was requested for ``run_id``.
        """
        if self.is_cancelled(run_id):
            logger.info(f"Cancelling run {run_id}")
            raise RunCancelledException(f"Run {run_id} was cancelled")

    async def araise_if_cancelled(self, run_id: str) -> None:
        """Check if a run should be cancelled and raise exception if so (async version).

        Raises:
            RunCancelledException: If cancellation was requested for ``run_id``.
        """
        if await self.ais_cancelled(run_id):
            logger.info(f"Cancelling run {run_id}")
            raise RunCancelledException(f"Run {run_id} was cancelled")

    def get_active_runs(self) -> Dict[str, bool]:
        """Get all currently tracked runs and their cancellation status."""
        return self._snapshot()

    async def aget_active_runs(self) -> Dict[str, bool]:
        """Get all currently tracked runs and their cancellation status (async version)."""
        return self._snapshot()
@@ -0,0 +1,236 @@
1
+ """Redis-based run cancellation management."""
2
+
3
+ from typing import TYPE_CHECKING, Any, Dict, Optional, Union
4
+
5
+ from agno.exceptions import RunCancelledException
6
+ from agno.run.cancellation_management.base import BaseRunCancellationManager
7
+ from agno.utils.log import logger
8
+
9
+ # Defer import error until class instantiation
10
+ _redis_available = True
11
+ _redis_import_error: Optional[str] = None
12
+
13
+ try:
14
+ from redis import Redis, RedisCluster
15
+ from redis.asyncio import Redis as AsyncRedis
16
+ from redis.asyncio import RedisCluster as AsyncRedisCluster
17
+ except ImportError:
18
+ _redis_available = False
19
+ _redis_import_error = "`redis` not installed. Please install it using `pip install redis`"
20
+ # Type hints for when redis is not installed
21
+ if TYPE_CHECKING:
22
+ from redis import Redis, RedisCluster
23
+ from redis.asyncio import Redis as AsyncRedis
24
+ from redis.asyncio import RedisCluster as AsyncRedisCluster
25
+ else:
26
+ Redis = Any
27
+ RedisCluster = Any
28
+ AsyncRedis = Any
29
+ AsyncRedisCluster = Any
30
+
31
+
32
class RedisRunCancellationManager(BaseRunCancellationManager):
    """Redis-based cancellation manager for distributed run cancellation.

    This manager stores run cancellation state in Redis, enabling cancellation
    across multiple processes or services. Each run is one key,
    ``f"{key_prefix}{run_id}"``, holding ``"0"`` (active) or ``"1"``
    (cancellation requested).

    To use: call the set_cancellation_manager function to set the cancellation manager.

    Args:
        redis_client: Sync Redis client for sync methods. Can be Redis or RedisCluster.
        async_redis_client: Async Redis client for async methods. Can be AsyncRedis or AsyncRedisCluster.
        key_prefix: Prefix for Redis keys. Defaults to "agno:run:cancellation:".
        ttl_seconds: TTL for keys in seconds. Defaults to 86400 (1 day).
            Keys auto-expire to prevent orphaned keys if runs aren't cleaned up.
            Set to None to disable expiration; otherwise it must be positive.

    Raises:
        ImportError: If the ``redis`` package is not installed.
        ValueError: If neither client is provided, or ``ttl_seconds`` is not positive.
    """

    DEFAULT_TTL_SECONDS = 60 * 60 * 24  # 1 day

    # Characters with special meaning in Redis glob-style MATCH patterns.
    _GLOB_SPECIAL_CHARS = "*?[]\\"

    def __init__(
        self,
        redis_client: Optional[Union[Redis, RedisCluster]] = None,
        async_redis_client: Optional[Union[AsyncRedis, AsyncRedisCluster]] = None,
        key_prefix: str = "agno:run:cancellation:",
        ttl_seconds: Optional[int] = DEFAULT_TTL_SECONDS,
    ):
        if not _redis_available:
            raise ImportError(_redis_import_error)

        super().__init__()
        self.redis_client = redis_client
        self.async_redis_client = async_redis_client
        self.key_prefix = key_prefix
        self.ttl_seconds = ttl_seconds

        if redis_client is None and async_redis_client is None:
            raise ValueError("At least one of redis_client or async_redis_client must be provided")
        # Redis rejects ``SET ... EX n`` for n <= 0 ("invalid expire time"), which
        # would make every register/cancel call fail; report it at construction.
        if ttl_seconds is not None and ttl_seconds <= 0:
            raise ValueError("ttl_seconds must be a positive integer or None")

    # -- helpers ---------------------------------------------------------------

    def _get_key(self, run_id: str) -> str:
        """Get the Redis key for a run ID."""
        return f"{self.key_prefix}{run_id}"

    def _scan_pattern(self) -> str:
        """Build a SCAN MATCH pattern covering exactly the keys under ``key_prefix``.

        Glob metacharacters in the prefix are backslash-escaped so they match
        literally instead of widening or breaking the pattern.
        """
        escaped = "".join("\\" + ch if ch in self._GLOB_SPECIAL_CHARS else ch for ch in self.key_prefix)
        return escaped + "*"

    @staticmethod
    def _decode_key(key: Any) -> str:
        """Return a scanned key as ``str``; clients may yield ``bytes``."""
        return key.decode("utf-8") if isinstance(key, bytes) else key

    @staticmethod
    def _is_cancelled_value(value: Any) -> bool:
        """Interpret a stored flag; ``None`` (missing key) means not cancelled.

        redis-py returns ``bytes`` unless the client decodes responses, so both
        encodings of ``"1"`` are accepted.
        """
        if value is None:
            return False
        if isinstance(value, bytes):
            return value == b"1"
        return value == "1"

    def _ensure_sync_client(self) -> Union[Redis, RedisCluster]:
        """Ensure sync client is available."""
        if self.redis_client is None:
            raise RuntimeError("Sync Redis client not provided. Use async methods or provide a sync client.")
        return self.redis_client

    def _ensure_async_client(self) -> Union[AsyncRedis, AsyncRedisCluster]:
        """Ensure async client is available."""
        if self.async_redis_client is None:
            raise RuntimeError("Async Redis client not provided. Use sync methods or provide an async client.")
        return self.async_redis_client

    # -- registration ----------------------------------------------------------

    def register_run(self, run_id: str) -> None:
        """Register a new run as not cancelled."""
        client = self._ensure_sync_client()
        client.set(self._get_key(run_id), "0", ex=self.ttl_seconds)

    async def aregister_run(self, run_id: str) -> None:
        """Register a new run as not cancelled (async version)."""
        client = self._ensure_async_client()
        await client.set(self._get_key(run_id), "0", ex=self.ttl_seconds)

    # -- cancellation ----------------------------------------------------------

    def cancel_run(self, run_id: str) -> bool:
        """Cancel a run by marking it as cancelled.

        Returns:
            bool: True if run was found and cancelled, False if run not found.
        """
        client = self._ensure_sync_client()
        # Atomically set to "1" only if key exists (XX flag)
        result = client.set(self._get_key(run_id), "1", ex=self.ttl_seconds, xx=True)

        if result:
            logger.info(f"Run {run_id} marked for cancellation")
            return True
        logger.warning(f"Attempted to cancel unknown run {run_id}")
        return False

    async def acancel_run(self, run_id: str) -> bool:
        """Cancel a run by marking it as cancelled (async version).

        Returns:
            bool: True if run was found and cancelled, False if run not found.
        """
        client = self._ensure_async_client()
        # Atomically set to "1" only if key exists (XX flag)
        result = await client.set(self._get_key(run_id), "1", ex=self.ttl_seconds, xx=True)

        if result:
            logger.info(f"Run {run_id} marked for cancellation")
            return True
        logger.warning(f"Attempted to cancel unknown run {run_id}")
        return False

    # -- queries ---------------------------------------------------------------

    def is_cancelled(self, run_id: str) -> bool:
        """Check if a run is cancelled."""
        client = self._ensure_sync_client()
        return self._is_cancelled_value(client.get(self._get_key(run_id)))

    async def ais_cancelled(self, run_id: str) -> bool:
        """Check if a run is cancelled (async version)."""
        client = self._ensure_async_client()
        return self._is_cancelled_value(await client.get(self._get_key(run_id)))

    # -- cleanup ---------------------------------------------------------------

    def cleanup_run(self, run_id: str) -> None:
        """Remove a run from tracking (called when run completes)."""
        client = self._ensure_sync_client()
        client.delete(self._get_key(run_id))

    async def acleanup_run(self, run_id: str) -> None:
        """Remove a run from tracking (called when run completes) (async version)."""
        client = self._ensure_async_client()
        await client.delete(self._get_key(run_id))

    # -- enforcement -----------------------------------------------------------

    def raise_if_cancelled(self, run_id: str) -> None:
        """Check if a run should be cancelled and raise exception if so.

        Raises:
            RunCancelledException: If cancellation was requested for ``run_id``.
        """
        if self.is_cancelled(run_id):
            logger.info(f"Cancelling run {run_id}")
            raise RunCancelledException(f"Run {run_id} was cancelled")

    async def araise_if_cancelled(self, run_id: str) -> None:
        """Check if a run should be cancelled and raise exception if so (async version).

        Raises:
            RunCancelledException: If cancellation was requested for ``run_id``.
        """
        if await self.ais_cancelled(run_id):
            logger.info(f"Cancelling run {run_id}")
            raise RunCancelledException(f"Run {run_id} was cancelled")

    # -- inspection ------------------------------------------------------------

    def get_active_runs(self) -> Dict[str, bool]:
        """Get all currently tracked runs and their cancellation status.

        Note: Uses scan_iter (SCAN, not KEYS), which works with both standalone
        Redis and Redis Cluster.
        """
        client = self._ensure_sync_client()
        result: Dict[str, bool] = {}
        prefix_len = len(self.key_prefix)

        for raw_key in client.scan_iter(match=self._scan_pattern(), count=100):
            key = self._decode_key(raw_key)
            value = client.get(key)
            # The key may have been deleted or expired between SCAN and GET.
            if value is not None:
                result[key[prefix_len:]] = self._is_cancelled_value(value)

        return result

    async def aget_active_runs(self) -> Dict[str, bool]:
        """Get all currently tracked runs and their cancellation status (async version).

        Note: Uses scan_iter (SCAN, not KEYS), which works with both standalone
        Redis and Redis Cluster.
        """
        client = self._ensure_async_client()
        result: Dict[str, bool] = {}
        prefix_len = len(self.key_prefix)

        async for raw_key in client.scan_iter(match=self._scan_pattern(), count=100):
            key = self._decode_key(raw_key)
            value = await client.get(key)
            # The key may have been deleted or expired between SCAN and GET.
            if value is not None:
                result[key[prefix_len:]] = self._is_cancelled_value(value)

        return result