offwork 0.4.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (42) hide show
  1. offwork/__init__.py +167 -0
  2. offwork/__main__.py +770 -0
  3. offwork/_venv.py +174 -0
  4. offwork/core/__init__.py +15 -0
  5. offwork/core/errors.py +83 -0
  6. offwork/core/models.py +174 -0
  7. offwork/core/pairing.py +389 -0
  8. offwork/core/progress.py +91 -0
  9. offwork/core/signing.py +91 -0
  10. offwork/core/task.py +520 -0
  11. offwork/core/token.py +184 -0
  12. offwork/core/version.py +10 -0
  13. offwork/graph/__init__.py +5 -0
  14. offwork/graph/analyzer.py +637 -0
  15. offwork/graph/decorator.py +87 -0
  16. offwork/graph/graph.py +995 -0
  17. offwork/graph/store.py +500 -0
  18. offwork/graph/tracing.py +429 -0
  19. offwork/py.typed +0 -0
  20. offwork/typing.py +48 -0
  21. offwork/worker/__init__.py +18 -0
  22. offwork/worker/backends/__init__.py +3 -0
  23. offwork/worker/backends/base.py +149 -0
  24. offwork/worker/backends/http.py +237 -0
  25. offwork/worker/backends/local.py +452 -0
  26. offwork/worker/backends/rabbitmq.py +410 -0
  27. offwork/worker/backends/redis.py +175 -0
  28. offwork/worker/deps.py +365 -0
  29. offwork/worker/remote.py +793 -0
  30. offwork/worker/result.py +276 -0
  31. offwork/worker/sandbox/Dockerfile +24 -0
  32. offwork/worker/sandbox/__init__.py +18 -0
  33. offwork/worker/sandbox/_protocol.py +50 -0
  34. offwork/worker/sandbox/docker.py +438 -0
  35. offwork/worker/sandbox/guest_agent.py +622 -0
  36. offwork/worker/schedule.py +26 -0
  37. offwork/worker/worker.py +263 -0
  38. offwork-0.4.0.dist-info/METADATA +143 -0
  39. offwork-0.4.0.dist-info/RECORD +42 -0
  40. offwork-0.4.0.dist-info/WHEEL +4 -0
  41. offwork-0.4.0.dist-info/entry_points.txt +3 -0
  42. offwork-0.4.0.dist-info/licenses/LICENSE +661 -0
@@ -0,0 +1,410 @@
1
+ """RabbitMQ backend for multi-machine task distribution.
2
+
3
+ Uses ``aio-pika`` (async AMQP 0-9-1 client) for task dispatch, result
4
+ routing, heartbeats, cancellation, and progress.
5
+
6
+ Tasks are dispatched via a durable queue. Per-task results use dedicated
7
+ queues with message TTL. Heartbeats, cancellation flags, and progress
8
+ data are stored in single-message queues (``x-max-length: 1``) that
9
+ behave like key-value slots. Result notifications use a fanout exchange.
10
+
11
+ URL scheme: ``amqp://`` or ``amqps://`` (e.g. ``amqp://guest:guest@localhost/``)
12
+ """
13
+ import time
14
+ import hashlib
15
+ import asyncio
16
+ import contextlib
17
+ from typing import Any
18
+ from collections.abc import AsyncIterator
19
+
20
+ try:
21
+ import aio_pika
22
+ except ImportError:
23
+ raise ImportError(
24
+ "aio-pika package is required for RabbitMQBackend. "
25
+ "Install it with: pip install aio-pika"
26
+ ) from None
27
+
28
+ from offwork.worker.backends.base import Backend
29
+
30
+
31
class RabbitMQBackend(Backend):
    """RabbitMQ-backed transport using ``aio-pika``.

    Parameters
    ----------
    url
        AMQP connection URL (e.g. ``amqp://guest:guest@localhost/``).
    task_queue
        Name of the durable task queue.
    result_ttl
        Seconds before result messages expire.

    Notes
    -----
    Short request/response operations share one channel guarded by
    ``self._lock``. Long-running consumers (``listen``, ``get_result``,
    ``subscribe_results``) open a dedicated channel via ``_new_channel`` so
    they do not hold the lock while waiting for messages.
    """

    TASK_QUEUE = "offwork.tasks"
    RESULT_PREFIX = "offwork.result."
    HEARTBEAT_PREFIX = "offwork.hb."
    CANCEL_PREFIX = "offwork.cancel."
    PROGRESS_PREFIX = "offwork.progress."
    SCHEDULE_PREFIX = "offwork.schedule."
    THROTTLE_PREFIX = "offwork.throttle."
    NOTIFY_EXCHANGE = "offwork.notify"

    DEFAULT_RESULT_TTL = 300  # seconds
    HEARTBEAT_TTL = 30
    CANCEL_TTL = 3600
    PROGRESS_TTL = 300
    SCHEDULE_TTL = 2592000  # 30 days
    # Fixed queue TTL used for throttle queues. Both check_throttle and
    # record_throttle always declare with the same arguments, avoiding
    # RabbitMQ PRECONDITION_FAILED errors. Actual throttle expiry is
    # encoded in the message value (expiry timestamp), not x-message-ttl.
    THROTTLE_QUEUE_TTL = 86400  # 24 hours

    def __init__(
        self,
        url: str = "amqp://localhost",
        *,
        task_queue: str | None = None,
        result_ttl: int | None = None,
    ) -> None:
        self._url = url
        # ``or``: both None and falsy values (e.g. 0 / "") fall back to
        # the class defaults.
        self._task_queue_name = task_queue or self.TASK_QUEUE
        self._result_ttl = result_ttl or self.DEFAULT_RESULT_TTL
        self._connection: Any = None
        self._channel: Any = None
        self._lock = asyncio.Lock()

    # -- connection management --------------------------------------------------

    async def _ensure_connection(self) -> Any:
        """Return the shared connection, connecting if needed.

        Caller must hold ``self._lock``. Opening a new connection drops the
        cached shared channel, because it belonged to the old connection.
        """
        if self._connection is None or self._connection.is_closed:
            self._connection = await aio_pika.connect_robust(self._url)
            self._channel = None
        return self._connection

    async def _get_channel(self) -> Any:
        """Return the shared channel, creating connection if needed."""
        async with self._lock:
            return await self._ensure_channel()

    async def _ensure_channel(self) -> Any:
        """Return the shared channel (caller must hold ``self._lock``)."""
        connection = await self._ensure_connection()
        if self._channel is None or self._channel.is_closed:
            self._channel = await connection.channel()
        return self._channel

    async def _new_channel(self) -> Any:
        """Create a dedicated channel for long-running operations.

        The lock is held only while (re)connecting; the returned channel is
        owned by the caller, which is responsible for closing it.
        """
        async with self._lock:
            connection = await self._ensure_connection()
            return await connection.channel()

    # -- internal helpers -------------------------------------------------------

    @staticmethod
    def _safe_suffix(value: str) -> str:
        """Return an AMQP-safe queue-name suffix for *value*.

        AMQP queue names only allow ``[a-zA-Z0-9-_.:@#,/ ]`` (pamqp enforces
        this in ``Queue.Declare.validate``). Function names can legally
        contain ``<locals>``, ``>``, ``[``, or non-ASCII identifier
        characters, which are rejected. Hash any value that would be invalid
        (or overlong) so it always produces a stable, valid suffix.
        """
        # Fast path: short values that already match the AMQP queue-name
        # regex can be used verbatim. ``str.isalnum`` alone also accepts
        # non-ASCII letters/digits (e.g. "é"), which are outside
        # ``[a-zA-Z0-9]``, so ASCII is required explicitly.
        if len(value) <= 128 and all(
            (c.isascii() and c.isalnum()) or c in "-_.:@#,/ "
            for c in value
        ):
            return value
        return hashlib.sha256(value.encode()).hexdigest()[:32]

    @staticmethod
    def _kv_args(ttl_s: int) -> dict[str, int]:
        """Queue arguments for a single-message key-value slot.

        Messages expire after *ttl_s* seconds; the queue itself is removed
        after twice that long without use.
        """
        return {
            "x-message-ttl": ttl_s * 1000,
            "x-max-length": 1,
            "x-expires": ttl_s * 2 * 1000,
        }

    def _result_args(self) -> dict[str, int]:
        """Queue arguments for a per-task result queue."""
        return {
            "x-message-ttl": self._result_ttl * 1000,
            "x-expires": self._result_ttl * 2 * 1000,
        }

    async def _declare_queue_robust(
        self, channel: Any, name: str, arguments: dict[str, Any],
    ) -> Any:
        """Declare a queue, recovering from PRECONDITION_FAILED.

        When RabbitMQ rejects a redeclaration because the existing queue
        has different arguments, it closes the channel. This helper
        catches that, reopens the channel, deletes the stale queue, and
        redeclares it with the new arguments.

        Only PRECONDITION_FAILED triggers the delete-and-redeclare path.
        Any other error (e.g. a dropped connection) propagates unchanged,
        so a transient failure never deletes a queue that may still hold
        live state such as a cancellation flag.

        Caller must hold ``self._lock`` (the shared channel is replaced).
        """
        try:
            return await channel.declare_queue(name, arguments=arguments)
        except aio_pika.exceptions.ChannelPreconditionFailed:
            # Channel is closed by RabbitMQ after PRECONDITION_FAILED.
            # Reopen a fresh channel, delete the stale queue, and retry.
            self._channel = None
            channel = await self._ensure_channel()
            try:
                await channel.queue_delete(name)
            except Exception:
                # Best effort: if the delete fails the broker may have
                # closed this channel too, so start from a fresh one.
                self._channel = None
                channel = await self._ensure_channel()
            return await channel.declare_queue(name, arguments=arguments)

    async def _kv_put(
        self, prefix: str, task_id: str, value: str, ttl_s: int,
    ) -> None:
        """Write to a per-task KV queue (``x-max-length: 1`` overwrites)."""
        async with self._lock:
            channel = await self._ensure_channel()
            name = f"{prefix}{task_id}"
            await self._declare_queue_robust(channel, name, self._kv_args(ttl_s))
            await channel.default_exchange.publish(
                aio_pika.Message(value.encode()),
                routing_key=name,
            )

    async def _kv_get(
        self, prefix: str, task_id: str, ttl_s: int, *, peek: bool = False,
    ) -> str | None:
        """Read from a per-task KV queue.

        When *peek* is ``True`` the message is nack'd back so future
        reads still see it (used for cancellation flags). Otherwise the
        message is consumed. Returns ``None`` when the slot is empty.
        """
        async with self._lock:
            channel = await self._ensure_channel()
            name = f"{prefix}{task_id}"
            queue = await self._declare_queue_robust(
                channel, name, self._kv_args(ttl_s),
            )
            # Peeking needs a manual ack mode so the message can be nack'd.
            msg = await queue.get(fail=False, no_ack=not peek)
            if msg is None:
                return None
            if peek:
                await msg.nack(requeue=True)
            raw: str = msg.body.decode()
            return raw

    # -- Backend interface: tasks -----------------------------------------------

    async def submit(self, task_json: str) -> None:
        """Publish *task_json* as a persistent message on the task queue."""
        async with self._lock:
            channel = await self._ensure_channel()
            await channel.declare_queue(self._task_queue_name, durable=True)
            await channel.default_exchange.publish(
                aio_pika.Message(
                    task_json.encode(),
                    delivery_mode=aio_pika.DeliveryMode.PERSISTENT,
                ),
                routing_key=self._task_queue_name,
            )

    async def listen(self) -> AsyncIterator[str]:
        """Yield task JSON strings from the durable task queue.

        Uses a dedicated channel with ``prefetch_count=1``. Each message is
        settled by ``message.process()`` when the caller resumes iteration
        after the ``yield``; the channel is closed when iteration ends.
        """
        channel = await self._new_channel()
        try:
            await channel.set_qos(prefetch_count=1)
            queue = await channel.declare_queue(
                self._task_queue_name, durable=True,
            )
            async with queue.iterator() as qi:
                async for message in qi:
                    async with message.process():
                        yield message.body.decode()
        finally:
            with contextlib.suppress(Exception):
                await channel.close()

    # -- Backend interface: results ---------------------------------------------

    async def send_result(self, task_id: str, result_json: str) -> None:
        """Publish *result_json* to the task's per-task result queue."""
        async with self._lock:
            channel = await self._ensure_channel()
            name = f"{self.RESULT_PREFIX}{task_id}"
            await channel.declare_queue(name, arguments=self._result_args())
            await channel.default_exchange.publish(
                aio_pika.Message(result_json.encode()),
                routing_key=name,
            )

    async def get_result(self, task_id: str, timeout: float | None = None) -> str:
        """Wait for and consume the result of *task_id*.

        Parameters
        ----------
        task_id
            Task whose result queue is consumed.
        timeout
            Seconds to wait, or ``None`` to wait indefinitely.

        Raises
        ------
        TimeoutError
            If *timeout* elapses before a result arrives.
        """
        channel = await self._new_channel()
        try:
            name = f"{self.RESULT_PREFIX}{task_id}"
            queue = await channel.declare_queue(
                name, arguments=self._result_args(),
            )
            future: asyncio.Future[str] = asyncio.get_running_loop().create_future()

            async def _on_message(msg: Any) -> None:
                await msg.ack()
                # Only the first delivery resolves the future.
                if not future.done():
                    future.set_result(msg.body.decode())

            tag = await queue.consume(_on_message)
            try:
                if timeout is not None:
                    try:
                        return await asyncio.wait_for(future, timeout=timeout)
                    except asyncio.TimeoutError:
                        raise TimeoutError(
                            f"Timed out waiting for result of task {task_id}"
                        ) from None
                return await future
            finally:
                with contextlib.suppress(Exception):
                    await queue.cancel(tag)
        finally:
            with contextlib.suppress(Exception):
                await channel.close()

    async def try_get_result(self, task_id: str) -> str | None:
        """Consume the result of *task_id* if present; else return ``None``."""
        async with self._lock:
            channel = await self._ensure_channel()
            name = f"{self.RESULT_PREFIX}{task_id}"
            queue = await channel.declare_queue(
                name, arguments=self._result_args(),
            )
            msg = await queue.get(fail=False)
            if msg is None:
                return None
            await msg.ack()
            raw: str = msg.body.decode()
            return raw

    # -- Heartbeat -------------------------------------------------------------

    async def send_heartbeat(self, task_id: str) -> None:
        """Store the current UNIX time as the task's latest heartbeat."""
        await self._kv_put(
            self.HEARTBEAT_PREFIX, task_id,
            str(time.time()), self.HEARTBEAT_TTL,
        )

    async def get_heartbeat(self, task_id: str) -> float | None:
        """Return the latest heartbeat timestamp, consuming it."""
        # Consume (ack) the heartbeat rather than peeking. This avoids a
        # RabbitMQ race where nack+requeue bypasses x-max-length=1 and causes
        # the stale heartbeat to be returned on every subsequent poll, making
        # stall detection fire spuriously.
        raw = await self._kv_get(
            self.HEARTBEAT_PREFIX, task_id, self.HEARTBEAT_TTL, peek=False,
        )
        return float(raw) if raw is not None else None

    # -- Cancellation ----------------------------------------------------------

    async def cancel_task(self, task_id: str) -> None:
        """Set the cancellation flag for *task_id*."""
        await self._kv_put(
            self.CANCEL_PREFIX, task_id, "1", self.CANCEL_TTL,
        )

    async def is_cancelled(self, task_id: str) -> bool:
        """Return whether the cancellation flag is set (non-destructive)."""
        raw = await self._kv_get(
            self.CANCEL_PREFIX, task_id, self.CANCEL_TTL, peek=True,
        )
        return raw is not None

    # -- Progress --------------------------------------------------------------

    async def send_progress(self, task_id: str, progress_json: str) -> None:
        """Overwrite the task's progress slot with *progress_json*."""
        await self._kv_put(
            self.PROGRESS_PREFIX, task_id, progress_json, self.PROGRESS_TTL,
        )

    async def get_progress(self, task_id: str) -> str | None:
        """Return the latest progress JSON without consuming it."""
        return await self._kv_get(
            self.PROGRESS_PREFIX, task_id, self.PROGRESS_TTL, peek=True,
        )

    # -- Schedule cancellation -------------------------------------------------

    async def cancel_schedule(self, schedule_id: str) -> None:
        """Set the cancellation flag for *schedule_id*."""
        await self._kv_put(
            self.SCHEDULE_PREFIX, schedule_id, "1", self.SCHEDULE_TTL,
        )

    async def is_schedule_cancelled(self, schedule_id: str) -> bool:
        """Return whether *schedule_id* is cancelled (non-destructive)."""
        raw = await self._kv_get(
            self.SCHEDULE_PREFIX, schedule_id, self.SCHEDULE_TTL, peek=True,
        )
        return raw is not None

    # -- Throttle --------------------------------------------------------------
    # Throttle state is stored as an expiry timestamp in the message body.
    # Both check and record always declare the queue with the same fixed TTL
    # (THROTTLE_QUEUE_TTL) so RabbitMQ never sees conflicting queue arguments.
    # The actual cooldown window is enforced by comparing time.time() against
    # the stored expiry value, not by x-message-ttl.

    async def check_throttle(self, function_name: str) -> bool:
        """Return ``True`` if *function_name* may run now."""
        # Uses THROTTLE_QUEUE_TTL for both declaration and storage so that
        # check_throttle and record_throttle always declare with identical
        # queue arguments (avoiding PRECONDITION_FAILED on redeclaration).
        # The actual throttle deadline is stored as a UNIX timestamp in the
        # message body. Function names can contain characters rejected by
        # the AMQP queue-name grammar (e.g. ``<locals>``), so we hash them.
        raw = await self._kv_get(
            self.THROTTLE_PREFIX, self._safe_suffix(function_name),
            self.THROTTLE_QUEUE_TTL, peek=True,
        )
        if raw is None:
            return True  # no throttle entry -> execution allowed
        return time.time() >= float(raw)  # allowed only if past expiry

    async def record_throttle(
        self, function_name: str, throttle_seconds: float,
    ) -> None:
        """Block *function_name* for the next *throttle_seconds* seconds."""
        expiry = time.time() + throttle_seconds
        await self._kv_put(
            self.THROTTLE_PREFIX, self._safe_suffix(function_name),
            str(expiry), self.THROTTLE_QUEUE_TTL,
        )

    # -- Result notifications --------------------------------------------------

    async def notify_result(self, task_id: str) -> None:
        """Broadcast *task_id* on the fanout notification exchange."""
        async with self._lock:
            channel = await self._ensure_channel()
            exchange = await channel.declare_exchange(
                self.NOTIFY_EXCHANGE, aio_pika.ExchangeType.FANOUT,
            )
            await exchange.publish(
                aio_pika.Message(task_id.encode()),
                routing_key="",
            )

    async def subscribe_results(self) -> AsyncIterator[str]:
        """Yield task IDs broadcast by ``notify_result``.

        Each subscriber binds its own exclusive queue to the fanout
        exchange on a dedicated channel, closed when iteration ends.
        """
        channel = await self._new_channel()
        try:
            exchange = await channel.declare_exchange(
                self.NOTIFY_EXCHANGE, aio_pika.ExchangeType.FANOUT,
            )
            queue = await channel.declare_queue(exclusive=True)
            await queue.bind(exchange)
            async with queue.iterator() as qi:
                async for message in qi:
                    async with message.process():
                        yield message.body.decode()
        finally:
            with contextlib.suppress(Exception):
                await channel.close()

    # -- Lifecycle -------------------------------------------------------------

    async def close(self) -> None:
        """Close the shared channel and connection (errors are ignored)."""
        if self._channel is not None:
            with contextlib.suppress(Exception):
                await self._channel.close()
            self._channel = None
        if self._connection is not None:
            with contextlib.suppress(Exception):
                await self._connection.close()
            self._connection = None
@@ -0,0 +1,175 @@
1
+ """Redis-backed transport using ``RPUSH``/``BLPOP`` for tasks and results."""
2
+
3
+ import time
4
+ import asyncio
5
+ from typing import Any
6
+ from collections.abc import AsyncIterator
7
+
8
+ try:
9
+ import redis.asyncio as _redis
10
+ except ImportError:
11
+ raise ImportError(
12
+ "redis package is required for RedisBackend. "
13
+ "Install it with: pip install redis"
14
+ ) from None
15
+
16
+ from offwork.worker.backends.base import Backend
17
+
18
+
19
class RedisBackend(Backend):
    """Redis-backed transport using ``RPUSH``/``BLPOP`` for tasks and results.

    Parameters
    ----------
    url
        Redis connection URL (e.g. ``redis://localhost:6379``).
    queue_key
        Redis key for the task queue.
    result_ttl
        Seconds before result keys expire.
    """

    DEFAULT_QUEUE_KEY = "offwork:tasks"
    RESULT_PREFIX = "offwork:result:"
    HEARTBEAT_PREFIX = "offwork:heartbeat:"
    CANCEL_PREFIX = "offwork:cancel:"
    PROGRESS_PREFIX = "offwork:progress:"
    SCHEDULE_PREFIX = "offwork:schedule:"
    THROTTLE_PREFIX = "offwork:throttle:"
    NOTIFY_CHANNEL = "offwork:notify"
    DEFAULT_RESULT_TTL = 300
    HEARTBEAT_TTL = 30
    CANCEL_TTL = 3600
    PROGRESS_TTL = 300
    SCHEDULE_TTL = 2592000  # 30 days

    def __init__(
        self,
        url: str = "redis://localhost:6379",
        *,
        queue_key: str | None = None,
        result_ttl: int | None = None,
    ) -> None:
        self._redis: Any = _redis.Redis.from_url(url)
        # ``or``: both None and falsy values fall back to the class defaults.
        self._queue_key = queue_key or self.DEFAULT_QUEUE_KEY
        self._result_ttl = result_ttl or self.DEFAULT_RESULT_TTL

    async def submit(self, task_json: str) -> None:
        """Append *task_json* to the tail of the task list (``RPUSH``)."""
        await self._redis.rpush(self._queue_key, task_json)

    async def listen(self) -> AsyncIterator[str]:
        """Block on ``BLPOP`` and yield task JSON strings as they arrive."""
        while True:
            result = await self._redis.blpop(self._queue_key)
            if result is None:
                continue
            _, raw = result
            yield raw.decode() if isinstance(raw, bytes) else raw

    async def send_result(self, task_id: str, result_json: str) -> None:
        """Push *result_json* onto the task's result list and set its TTL.

        ``RPUSH`` and ``EXPIRE`` are sent in one ``MULTI``/``EXEC``
        transaction, so a crash can no longer leave the result key behind
        without an expiry.
        """
        key = f"{self.RESULT_PREFIX}{task_id}"
        async with self._redis.pipeline(transaction=True) as pipe:
            pipe.rpush(key, result_json)
            pipe.expire(key, self._result_ttl)
            await pipe.execute()

    async def get_result(self, task_id: str, timeout: float | None = None) -> str:
        """Block until the result of *task_id* is available and pop it.

        Parameters
        ----------
        task_id
            Task whose result list is popped.
        timeout
            Seconds to wait. ``None`` (or 0) waits indefinitely. Positive
            values are rounded *up* to whole seconds for ``BLPOP``.

        Raises
        ------
        TimeoutError
            If no result arrives within *timeout*.
        """
        import math

        key = f"{self.RESULT_PREFIX}{task_id}"
        # BLPOP treats 0 as "block forever", so truncating a sub-second
        # timeout (int(0.5) == 0) would hang indefinitely. Round up instead,
        # never going below one second.
        t = max(1, math.ceil(timeout)) if timeout else 0
        result = await self._redis.blpop(key, timeout=t)
        if result is None:
            raise TimeoutError(
                f"Timed out waiting for result of task {task_id}"
            )
        _, raw = result
        return raw.decode() if isinstance(raw, bytes) else raw

    async def try_get_result(self, task_id: str) -> str | None:
        """Non-blocking ``LPOP``; returns ``None`` if not yet available."""
        key = f"{self.RESULT_PREFIX}{task_id}"
        raw = await self._redis.lpop(key)
        if raw is None:
            return None
        return raw.decode() if isinstance(raw, bytes) else raw

    async def send_heartbeat(self, task_id: str) -> None:
        """Store the current UNIX time as the heartbeat, expiring after TTL."""
        key = f"{self.HEARTBEAT_PREFIX}{task_id}"
        await self._redis.set(key, str(time.time()), ex=self.HEARTBEAT_TTL)

    async def get_heartbeat(self, task_id: str) -> float | None:
        """Return the last heartbeat timestamp, or ``None`` if absent."""
        key = f"{self.HEARTBEAT_PREFIX}{task_id}"
        raw = await self._redis.get(key)
        if raw is None:
            return None
        return float(raw)

    async def get_heartbeats(self, task_ids: list[str]) -> dict[str, float | None]:
        """Batch fetch via ``MGET`` for efficiency."""
        if not task_ids:
            return {}
        keys = [f"{self.HEARTBEAT_PREFIX}{tid}" for tid in task_ids]
        values = await self._redis.mget(keys)
        return {
            tid: float(v) if v is not None else None
            for tid, v in zip(task_ids, values)
        }

    async def cancel_task(self, task_id: str) -> None:
        """Set the cancellation flag for *task_id*."""
        key = f"{self.CANCEL_PREFIX}{task_id}"
        await self._redis.set(key, "1", ex=self.CANCEL_TTL)

    async def is_cancelled(self, task_id: str) -> bool:
        """Return whether the cancellation flag for *task_id* exists."""
        key = f"{self.CANCEL_PREFIX}{task_id}"
        return int(await self._redis.exists(key)) > 0

    async def send_progress(self, task_id: str, progress_json: str) -> None:
        """Overwrite the task's progress value with *progress_json*."""
        key = f"{self.PROGRESS_PREFIX}{task_id}"
        await self._redis.set(key, progress_json, ex=self.PROGRESS_TTL)

    async def get_progress(self, task_id: str) -> str | None:
        """Return the latest progress JSON, or ``None`` if absent."""
        key = f"{self.PROGRESS_PREFIX}{task_id}"
        raw = await self._redis.get(key)
        if raw is None:
            return None
        return raw.decode() if isinstance(raw, bytes) else raw

    async def cancel_schedule(self, schedule_id: str) -> None:
        """Set the cancellation flag for *schedule_id*."""
        key = f"{self.SCHEDULE_PREFIX}{schedule_id}"
        await self._redis.set(key, "1", ex=self.SCHEDULE_TTL)

    async def is_schedule_cancelled(self, schedule_id: str) -> bool:
        """Return whether the cancellation flag for *schedule_id* exists."""
        key = f"{self.SCHEDULE_PREFIX}{schedule_id}"
        return int(await self._redis.exists(key)) > 0

    async def check_throttle(self, function_name: str) -> bool:
        """Return ``True`` if no throttle key exists for *function_name*."""
        key = f"{self.THROTTLE_PREFIX}{function_name}"
        return int(await self._redis.exists(key)) == 0

    async def record_throttle(
        self, function_name: str, throttle_seconds: float,
    ) -> None:
        """Block *function_name* for at least *throttle_seconds* seconds.

        The expiry is rounded up to whole seconds (minimum 1), so the
        window is never shorter than requested. ``int()`` truncation used
        to turn e.g. 1.9 s into 1 s.
        """
        import math

        key = f"{self.THROTTLE_PREFIX}{function_name}"
        await self._redis.set(key, "1", ex=max(1, math.ceil(throttle_seconds)))

    async def notify_result(self, task_id: str) -> None:
        """Publish task_id on the Pub/Sub notification channel."""
        await self._redis.publish(self.NOTIFY_CHANNEL, task_id)

    async def subscribe_results(self) -> AsyncIterator[str]:
        """Subscribe to the Pub/Sub channel and yield task IDs on result arrival."""
        pubsub = self._redis.pubsub()
        try:
            await pubsub.subscribe(self.NOTIFY_CHANNEL)
            while True:
                msg = await pubsub.get_message(
                    ignore_subscribe_messages=True, timeout=1.0,
                )
                if msg is not None and msg["type"] == "message":
                    data = msg["data"]
                    yield data.decode() if isinstance(data, bytes) else data
                elif msg is None:
                    await asyncio.sleep(0.01)
        finally:
            try:
                await pubsub.unsubscribe(self.NOTIFY_CHANNEL)
            finally:
                # Release the Pub/Sub connection even if UNSUBSCRIBE fails
                # (e.g. the connection has already dropped).
                await pubsub.aclose()

    async def close(self) -> None:
        """Close the underlying Redis client."""
        await self._redis.aclose()