offwork 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- offwork/__init__.py +167 -0
- offwork/__main__.py +770 -0
- offwork/_venv.py +174 -0
- offwork/core/__init__.py +15 -0
- offwork/core/errors.py +83 -0
- offwork/core/models.py +174 -0
- offwork/core/pairing.py +389 -0
- offwork/core/progress.py +91 -0
- offwork/core/signing.py +91 -0
- offwork/core/task.py +520 -0
- offwork/core/token.py +184 -0
- offwork/core/version.py +10 -0
- offwork/graph/__init__.py +5 -0
- offwork/graph/analyzer.py +637 -0
- offwork/graph/decorator.py +87 -0
- offwork/graph/graph.py +995 -0
- offwork/graph/store.py +500 -0
- offwork/graph/tracing.py +429 -0
- offwork/py.typed +0 -0
- offwork/typing.py +48 -0
- offwork/worker/__init__.py +18 -0
- offwork/worker/backends/__init__.py +3 -0
- offwork/worker/backends/base.py +149 -0
- offwork/worker/backends/http.py +237 -0
- offwork/worker/backends/local.py +452 -0
- offwork/worker/backends/rabbitmq.py +410 -0
- offwork/worker/backends/redis.py +175 -0
- offwork/worker/deps.py +365 -0
- offwork/worker/remote.py +793 -0
- offwork/worker/result.py +276 -0
- offwork/worker/sandbox/Dockerfile +24 -0
- offwork/worker/sandbox/__init__.py +18 -0
- offwork/worker/sandbox/_protocol.py +50 -0
- offwork/worker/sandbox/docker.py +438 -0
- offwork/worker/sandbox/guest_agent.py +622 -0
- offwork/worker/schedule.py +26 -0
- offwork/worker/worker.py +263 -0
- offwork-0.1.0.dist-info/METADATA +143 -0
- offwork-0.1.0.dist-info/RECORD +42 -0
- offwork-0.1.0.dist-info/WHEEL +4 -0
- offwork-0.1.0.dist-info/entry_points.txt +3 -0
- offwork-0.1.0.dist-info/licenses/LICENSE +661 -0
|
@@ -0,0 +1,410 @@
|
|
|
1
|
+
"""RabbitMQ backend for multi-machine task distribution.
|
|
2
|
+
|
|
3
|
+
Uses ``aio-pika`` (async AMQP 0-9-1 client) for task dispatch, result
|
|
4
|
+
routing, heartbeats, cancellation, and progress.
|
|
5
|
+
|
|
6
|
+
Tasks are dispatched via a durable queue. Per-task results use dedicated
|
|
7
|
+
queues with message TTL. Heartbeats, cancellation flags, and progress
|
|
8
|
+
data are stored in single-message queues (``x-max-length: 1``) that
|
|
9
|
+
behave like key-value slots. Result notifications use a fanout exchange.
|
|
10
|
+
|
|
11
|
+
URL scheme: ``amqp://`` or ``amqps://`` (e.g. ``amqp://guest:guest@localhost/``)
|
|
12
|
+
"""
|
|
13
|
+
import time
|
|
14
|
+
import hashlib
|
|
15
|
+
import asyncio
|
|
16
|
+
import contextlib
|
|
17
|
+
from typing import Any
|
|
18
|
+
from collections.abc import AsyncIterator
|
|
19
|
+
|
|
20
|
+
try:
|
|
21
|
+
import aio_pika
|
|
22
|
+
except ImportError:
|
|
23
|
+
raise ImportError(
|
|
24
|
+
"aio-pika package is required for RabbitMQBackend. "
|
|
25
|
+
"Install it with: pip install aio-pika"
|
|
26
|
+
) from None
|
|
27
|
+
|
|
28
|
+
from offwork.worker.backends.base import Backend
|
|
29
|
+
|
|
30
|
+
|
|
31
|
+
class RabbitMQBackend(Backend):
    """RabbitMQ-backed transport using ``aio-pika``.

    A single shared channel (guarded by an ``asyncio.Lock``) serves all
    short request/response style operations. Long-running consumers
    (``listen``, ``get_result``, ``subscribe_results``) each open a
    dedicated channel and close it when they finish.

    Parameters
    ----------
    url
        AMQP connection URL (e.g. ``amqp://guest:guest@localhost/``).
    task_queue
        Name of the durable task queue.
    result_ttl
        Seconds before result messages expire.
    """

    TASK_QUEUE = "offwork.tasks"
    RESULT_PREFIX = "offwork.result."
    HEARTBEAT_PREFIX = "offwork.hb."
    CANCEL_PREFIX = "offwork.cancel."
    PROGRESS_PREFIX = "offwork.progress."
    SCHEDULE_PREFIX = "offwork.schedule."
    THROTTLE_PREFIX = "offwork.throttle."
    NOTIFY_EXCHANGE = "offwork.notify"

    DEFAULT_RESULT_TTL = 300  # seconds
    HEARTBEAT_TTL = 30
    CANCEL_TTL = 3600
    PROGRESS_TTL = 300
    SCHEDULE_TTL = 2592000  # 30 days
    # Fixed queue TTL used for throttle queues.  Both check_throttle and
    # record_throttle always declare with the same arguments, avoiding
    # RabbitMQ PRECONDITION_FAILED errors.  Actual throttle expiry is
    # encoded in the message value (expiry timestamp), not x-message-ttl.
    THROTTLE_QUEUE_TTL = 86400  # 24 hours

    def __init__(
        self,
        url: str = "amqp://localhost",
        *,
        task_queue: str | None = None,
        result_ttl: int | None = None,
    ) -> None:
        """Store configuration; no connection is opened until first use."""
        self._url = url
        # Falsy values (None, "" or 0) fall back to the class defaults.
        self._task_queue_name = task_queue or self.TASK_QUEUE
        self._result_ttl = result_ttl or self.DEFAULT_RESULT_TTL
        self._connection: Any = None
        self._channel: Any = None
        # Serialises every use of the shared connection/channel pair.
        self._lock = asyncio.Lock()

    # -- connection management --------------------------------------------------

    async def _get_channel(self) -> Any:
        """Return the shared channel, creating connection if needed."""
        async with self._lock:
            return await self._ensure_channel()

    async def _ensure_connection(self) -> Any:
        """Return the shared connection (caller must hold ``self._lock``).

        A new robust connection is opened when none exists or the current
        one is closed; the cached shared channel is dropped in that case
        because it belonged to the old connection.
        """
        if self._connection is None or self._connection.is_closed:
            self._connection = await aio_pika.connect_robust(self._url)
            self._channel = None
        return self._connection

    async def _ensure_channel(self) -> Any:
        """Return the shared channel (caller must hold ``self._lock``)."""
        connection = await self._ensure_connection()
        if self._channel is None or self._channel.is_closed:
            self._channel = await connection.channel()
        return self._channel

    async def _new_channel(self) -> Any:
        """Create a dedicated channel for long-running operations.

        The caller owns the returned channel and is responsible for
        closing it.
        """
        async with self._lock:
            connection = await self._ensure_connection()
            return await connection.channel()

    # -- internal helpers -------------------------------------------------------

    @staticmethod
    def _safe_suffix(value: str) -> str:
        """Return an AMQP-safe queue-name suffix for *value*.

        AMQP queue names only allow ``[a-zA-Z0-9-_.:@#,/ ]`` (pamqp enforces
        this in ``Queue.Declare.validate``).  Function names can legally
        contain ``<locals>``, ``>``, ``[``, etc., which are rejected.
        Hash any value that would be invalid (or overlong) so it always
        produces a stable, valid suffix.
        """
        # Fast path: short values that already match the AMQP queue-name
        # regex can be used verbatim.  ``str.isalnum`` alone also accepts
        # non-ASCII letters and digits, which the grammar above rejects, so
        # the check is restricted to ASCII.
        if len(value) <= 128 and all(
            (c.isascii() and c.isalnum()) or c in "-_.:@#,/ " for c in value
        ):
            return value
        return hashlib.sha256(value.encode()).hexdigest()[:32]

    @staticmethod
    def _kv_args(ttl_s: int) -> dict[str, int]:
        """Queue arguments for a single-message key-value slot.

        The message TTL is *ttl_s* and the queue itself expires after twice
        that long; both are passed to RabbitMQ in milliseconds.
        """
        return {
            "x-message-ttl": ttl_s * 1000,
            "x-max-length": 1,
            "x-expires": ttl_s * 2 * 1000,
        }

    def _result_args(self) -> dict[str, int]:
        """Queue arguments for a per-task result queue."""
        return {
            "x-message-ttl": self._result_ttl * 1000,
            "x-expires": self._result_ttl * 2 * 1000,
        }

    async def _declare_queue_robust(
        self, channel: Any, name: str, arguments: dict[str, Any],
    ) -> Any:
        """Declare a queue, recovering from PRECONDITION_FAILED.

        When RabbitMQ rejects a redeclaration because the existing queue
        has different arguments, it closes the channel.  This helper
        catches that, reopens the channel, deletes the stale queue, and
        redeclares it with the new arguments.

        The caller must hold ``self._lock``.  After a recovery the shared
        channel has been replaced, so the *channel* object passed in must
        not be reused; the returned queue is bound to the fresh channel.
        Any exception from the first declare triggers the recovery path,
        and deleting the queue discards whatever it currently holds.
        """
        try:
            return await channel.declare_queue(name, arguments=arguments)
        except Exception:
            # Channel is closed by RabbitMQ after PRECONDITION_FAILED.
            # Reopen a fresh channel, purge the stale queue, and retry.
            self._channel = None
            channel = await self._ensure_channel()
            try:
                await channel.queue_delete(name)
            except Exception:
                # The delete attempt may itself have cost us the channel;
                # start over with a new one for the final declare.
                self._channel = None
                channel = await self._ensure_channel()
            return await channel.declare_queue(name, arguments=arguments)

    async def _kv_put(
        self, prefix: str, task_id: str, value: str, ttl_s: int,
    ) -> None:
        """Write to a per-task KV queue (``x-max-length: 1`` overwrites)."""
        async with self._lock:
            channel = await self._ensure_channel()
            name = f"{prefix}{task_id}"
            await self._declare_queue_robust(channel, name, self._kv_args(ttl_s))
            # The declare may have replaced the shared channel while
            # recovering; re-fetch it so we never publish on a closed one.
            channel = await self._ensure_channel()
            await channel.default_exchange.publish(
                aio_pika.Message(value.encode()),
                routing_key=name,
            )

    async def _kv_get(
        self, prefix: str, task_id: str, ttl_s: int, *, peek: bool = False,
    ) -> str | None:
        """Read from a per-task KV queue.

        When *peek* is ``True`` the message is nack'd back so future
        reads still see it (used for cancellation flags).  Otherwise the
        message is consumed.

        Returns the decoded message body, or ``None`` when the slot is
        empty.
        """
        async with self._lock:
            channel = await self._ensure_channel()
            name = f"{prefix}{task_id}"
            queue = await self._declare_queue_robust(
                channel, name, self._kv_args(ttl_s),
            )
            # no_ack=True consumes on delivery; a peek needs manual ack
            # mode so the message can be requeued below.
            msg = await queue.get(fail=False, no_ack=not peek)
            if msg is None:
                return None
            if peek:
                await msg.nack(requeue=True)
            raw: str = msg.body.decode()
            return raw

    # -- Backend interface: tasks -----------------------------------------------

    async def submit(self, task_json: str) -> None:
        """Publish *task_json* as a persistent message on the task queue."""
        async with self._lock:
            channel = await self._ensure_channel()
            await channel.declare_queue(self._task_queue_name, durable=True)
            await channel.default_exchange.publish(
                aio_pika.Message(
                    task_json.encode(),
                    delivery_mode=aio_pika.DeliveryMode.PERSISTENT,
                ),
                routing_key=self._task_queue_name,
            )

    async def listen(self) -> AsyncIterator[str]:
        """Yield task JSON strings from the task queue as they arrive.

        Uses a dedicated channel with ``prefetch_count=1``.  Each message
        is wrapped in ``message.process()`` around the ``yield``, so it is
        only settled once the consumer resumes the generator.
        """
        channel = await self._new_channel()
        try:
            await channel.set_qos(prefetch_count=1)
            queue = await channel.declare_queue(
                self._task_queue_name, durable=True,
            )
            async with queue.iterator() as qi:
                async for message in qi:
                    async with message.process():
                        yield message.body.decode()
        finally:
            with contextlib.suppress(Exception):
                await channel.close()

    # -- Backend interface: results ---------------------------------------------

    async def send_result(self, task_id: str, result_json: str) -> None:
        """Publish *result_json* to the result queue of *task_id*."""
        async with self._lock:
            channel = await self._ensure_channel()
            name = f"{self.RESULT_PREFIX}{task_id}"
            await channel.declare_queue(name, arguments=self._result_args())
            await channel.default_exchange.publish(
                aio_pika.Message(result_json.encode()),
                routing_key=name,
            )

    async def get_result(self, task_id: str, timeout: float | None = None) -> str:
        """Wait for and return the result of *task_id*.

        Parameters
        ----------
        task_id
            Task whose result queue is consumed.
        timeout
            Maximum seconds to wait; ``None`` waits indefinitely.

        Raises
        ------
        TimeoutError
            If no result arrives within *timeout* seconds.
        """
        channel = await self._new_channel()
        try:
            name = f"{self.RESULT_PREFIX}{task_id}"
            queue = await channel.declare_queue(
                name, arguments=self._result_args(),
            )
            future: asyncio.Future[str] = asyncio.get_running_loop().create_future()

            async def _on_message(msg: Any) -> None:
                await msg.ack()
                # Only the first delivery resolves the future.
                if not future.done():
                    future.set_result(msg.body.decode())

            tag = await queue.consume(_on_message)
            try:
                if timeout is not None:
                    try:
                        return await asyncio.wait_for(future, timeout=timeout)
                    except asyncio.TimeoutError:
                        raise TimeoutError(
                            f"Timed out waiting for result of task {task_id}"
                        ) from None
                return await future
            finally:
                with contextlib.suppress(Exception):
                    await queue.cancel(tag)
        finally:
            with contextlib.suppress(Exception):
                await channel.close()

    async def try_get_result(self, task_id: str) -> str | None:
        """Return the result of *task_id* if one is queued, else ``None``.

        A returned result is acknowledged and therefore consumed.
        """
        async with self._lock:
            channel = await self._ensure_channel()
            name = f"{self.RESULT_PREFIX}{task_id}"
            queue = await channel.declare_queue(
                name, arguments=self._result_args(),
            )
            msg = await queue.get(fail=False)
            if msg is None:
                return None
            await msg.ack()
            raw: str = msg.body.decode()
            return raw

    # -- Heartbeat -------------------------------------------------------------

    async def send_heartbeat(self, task_id: str) -> None:
        """Store the current UNIX time as the heartbeat of *task_id*."""
        await self._kv_put(
            self.HEARTBEAT_PREFIX, task_id,
            str(time.time()), self.HEARTBEAT_TTL,
        )

    async def get_heartbeat(self, task_id: str) -> float | None:
        """Return and consume the stored heartbeat timestamp, if any."""
        # Consume (ack) the heartbeat rather than peeking.  This avoids a
        # RabbitMQ race where nack+requeue bypasses x-max-length=1 and causes
        # the stale heartbeat to be returned on every subsequent poll, making
        # stall detection fire spuriously.
        raw = await self._kv_get(
            self.HEARTBEAT_PREFIX, task_id, self.HEARTBEAT_TTL, peek=False,
        )
        return float(raw) if raw is not None else None

    # -- Cancellation ----------------------------------------------------------

    async def cancel_task(self, task_id: str) -> None:
        """Set the cancellation flag for *task_id*."""
        await self._kv_put(
            self.CANCEL_PREFIX, task_id, "1", self.CANCEL_TTL,
        )

    async def is_cancelled(self, task_id: str) -> bool:
        """Return ``True`` if a cancellation flag exists for *task_id*."""
        raw = await self._kv_get(
            self.CANCEL_PREFIX, task_id, self.CANCEL_TTL, peek=True,
        )
        return raw is not None

    # -- Progress --------------------------------------------------------------

    async def send_progress(self, task_id: str, progress_json: str) -> None:
        """Store *progress_json* as the latest progress of *task_id*."""
        await self._kv_put(
            self.PROGRESS_PREFIX, task_id, progress_json, self.PROGRESS_TTL,
        )

    async def get_progress(self, task_id: str) -> str | None:
        """Return the latest stored progress JSON without consuming it."""
        return await self._kv_get(
            self.PROGRESS_PREFIX, task_id, self.PROGRESS_TTL, peek=True,
        )

    # -- Schedule cancellation -------------------------------------------------

    async def cancel_schedule(self, schedule_id: str) -> None:
        """Set the cancellation flag for *schedule_id*."""
        await self._kv_put(
            self.SCHEDULE_PREFIX, schedule_id, "1", self.SCHEDULE_TTL,
        )

    async def is_schedule_cancelled(self, schedule_id: str) -> bool:
        """Return ``True`` if a cancellation flag exists for *schedule_id*."""
        raw = await self._kv_get(
            self.SCHEDULE_PREFIX, schedule_id, self.SCHEDULE_TTL, peek=True,
        )
        return raw is not None

    # -- Throttle --------------------------------------------------------------
    # Throttle state is stored as an expiry timestamp in the message body.
    # Both check and record always declare the queue with the same fixed TTL
    # (THROTTLE_QUEUE_TTL) so RabbitMQ never sees conflicting queue arguments.
    # The actual cooldown window is enforced by comparing time.time() against
    # the stored expiry value, not by x-message-ttl.

    async def check_throttle(self, function_name: str) -> bool:
        """Return ``True`` if *function_name* is currently allowed to run."""
        # Uses THROTTLE_QUEUE_TTL for both declaration and storage so that
        # check_throttle and record_throttle always declare with identical
        # queue arguments (avoiding PRECONDITION_FAILED on redeclaration).
        # The actual throttle deadline is stored as a UNIX timestamp in the
        # message body.  Function names can contain characters rejected by
        # the AMQP queue-name grammar (e.g. ``<locals>``), so we hash them.
        raw = await self._kv_get(
            self.THROTTLE_PREFIX, self._safe_suffix(function_name),
            self.THROTTLE_QUEUE_TTL, peek=True,
        )
        if raw is None:
            return True  # no throttle entry → execution allowed
        return time.time() >= float(raw)  # allowed only if past expiry

    async def record_throttle(
        self, function_name: str, throttle_seconds: float,
    ) -> None:
        """Record that *function_name* is throttled for *throttle_seconds*.

        NOTE: the stored entry carries a message TTL of
        ``THROTTLE_QUEUE_TTL``, so a window longer than that stops being
        enforced once the message expires.
        """
        expiry = time.time() + throttle_seconds
        await self._kv_put(
            self.THROTTLE_PREFIX, self._safe_suffix(function_name),
            str(expiry), self.THROTTLE_QUEUE_TTL,
        )

    # -- Result notifications --------------------------------------------------

    async def notify_result(self, task_id: str) -> None:
        """Broadcast *task_id* on the fanout notification exchange."""
        async with self._lock:
            channel = await self._ensure_channel()
            exchange = await channel.declare_exchange(
                self.NOTIFY_EXCHANGE, aio_pika.ExchangeType.FANOUT,
            )
            await exchange.publish(
                aio_pika.Message(task_id.encode()),
                routing_key="",
            )

    async def subscribe_results(self) -> AsyncIterator[str]:
        """Yield task IDs published on the notification exchange.

        Binds a server-named exclusive queue to the fanout exchange on a
        dedicated channel, which is closed when the generator finishes.
        """
        channel = await self._new_channel()
        try:
            exchange = await channel.declare_exchange(
                self.NOTIFY_EXCHANGE, aio_pika.ExchangeType.FANOUT,
            )
            queue = await channel.declare_queue(exclusive=True)
            await queue.bind(exchange)
            async with queue.iterator() as qi:
                async for message in qi:
                    async with message.process():
                        yield message.body.decode()
        finally:
            with contextlib.suppress(Exception):
                await channel.close()

    # -- Lifecycle -------------------------------------------------------------

    async def close(self) -> None:
        """Close the shared channel and connection, ignoring close errors."""
        if self._channel is not None:
            with contextlib.suppress(Exception):
                await self._channel.close()
            self._channel = None
        if self._connection is not None:
            with contextlib.suppress(Exception):
                await self._connection.close()
            self._connection = None
|
|
@@ -0,0 +1,175 @@
|
|
|
1
|
+
"""Redis-backed transport using ``RPUSH``/``BLPOP`` for tasks and results."""
|
|
2
|
+
|
|
3
|
+
import time
|
|
4
|
+
import asyncio
|
|
5
|
+
from typing import Any
|
|
6
|
+
from collections.abc import AsyncIterator
|
|
7
|
+
|
|
8
|
+
try:
|
|
9
|
+
import redis.asyncio as _redis
|
|
10
|
+
except ImportError:
|
|
11
|
+
raise ImportError(
|
|
12
|
+
"redis package is required for RedisBackend. "
|
|
13
|
+
"Install it with: pip install redis"
|
|
14
|
+
) from None
|
|
15
|
+
|
|
16
|
+
from offwork.worker.backends.base import Backend
|
|
17
|
+
|
|
18
|
+
|
|
19
|
+
class RedisBackend(Backend):
    """Redis-backed transport using ``RPUSH``/``BLPOP`` for tasks and results.

    Parameters
    ----------
    url
        Redis connection URL (e.g. ``redis://localhost:6379``).
    queue_key
        Redis key for the task queue.
    result_ttl
        Seconds before result keys expire.
    """

    DEFAULT_QUEUE_KEY = "offwork:tasks"
    RESULT_PREFIX = "offwork:result:"
    HEARTBEAT_PREFIX = "offwork:heartbeat:"
    CANCEL_PREFIX = "offwork:cancel:"
    PROGRESS_PREFIX = "offwork:progress:"
    SCHEDULE_PREFIX = "offwork:schedule:"
    THROTTLE_PREFIX = "offwork:throttle:"
    NOTIFY_CHANNEL = "offwork:notify"
    DEFAULT_RESULT_TTL = 300
    HEARTBEAT_TTL = 30
    CANCEL_TTL = 3600
    PROGRESS_TTL = 300
    SCHEDULE_TTL = 2592000  # 30 days

    def __init__(
        self,
        url: str = "redis://localhost:6379",
        *,
        queue_key: str | None = None,
        result_ttl: int | None = None,
    ) -> None:
        """Create the Redis client for *url* and store configuration."""
        self._redis: Any = _redis.Redis.from_url(url)
        # Falsy values (None, "" or 0) fall back to the class defaults.
        self._queue_key = queue_key or self.DEFAULT_QUEUE_KEY
        self._result_ttl = result_ttl or self.DEFAULT_RESULT_TTL

    @staticmethod
    def _to_str(raw: Any) -> str:
        """Decode *raw* if it is ``bytes``; otherwise return it unchanged."""
        return raw.decode() if isinstance(raw, bytes) else raw

    async def submit(self, task_json: str) -> None:
        """Append *task_json* to the task queue with ``RPUSH``."""
        await self._redis.rpush(self._queue_key, task_json)

    async def listen(self) -> AsyncIterator[str]:
        """Block on ``BLPOP`` and yield task JSON strings as they arrive."""
        while True:
            result = await self._redis.blpop(self._queue_key)
            if result is None:
                continue
            _, raw = result
            yield self._to_str(raw)

    async def send_result(self, task_id: str, result_json: str) -> None:
        """Store *result_json* for *task_id* and set the key's expiry.

        ``RPUSH`` and ``EXPIRE`` are sent in one ``MULTI``/``EXEC``
        transaction so a failure between the two commands cannot leave a
        result key without a TTL.
        """
        key = f"{self.RESULT_PREFIX}{task_id}"
        async with self._redis.pipeline(transaction=True) as pipe:
            pipe.rpush(key, result_json)
            pipe.expire(key, self._result_ttl)
            await pipe.execute()

    async def get_result(self, task_id: str, timeout: float | None = None) -> str:
        """Wait for and return the result of *task_id*.

        Parameters
        ----------
        task_id
            Task whose result list is popped.
        timeout
            Maximum seconds to wait.  ``None`` waits indefinitely; a value
            ``<= 0`` performs a single non-blocking check.  Positive values
            are rounded up to whole seconds.

        Raises
        ------
        TimeoutError
            If no result is available within *timeout*.
        """
        key = f"{self.RESULT_PREFIX}{task_id}"
        if timeout is None:
            # BLPOP interprets a timeout of 0 as "block indefinitely".
            result = await self._redis.blpop(key, timeout=0)
        elif timeout <= 0:
            # A zero/negative timeout must not be forwarded to BLPOP (0
            # would block forever); do a one-shot LPOP instead.
            popped = await self._redis.lpop(key)
            result = None if popped is None else (key, popped)
        else:
            # Round up so fractional timeouts such as 0.5 never truncate
            # to 0, which BLPOP would treat as "no timeout".
            seconds = int(timeout)
            if seconds < timeout:
                seconds += 1
            result = await self._redis.blpop(key, timeout=seconds)
        if result is None:
            raise TimeoutError(
                f"Timed out waiting for result of task {task_id}"
            )
        _, raw = result
        return self._to_str(raw)

    async def try_get_result(self, task_id: str) -> str | None:
        """Non-blocking ``LPOP``; returns ``None`` if not yet available."""
        key = f"{self.RESULT_PREFIX}{task_id}"
        raw = await self._redis.lpop(key)
        if raw is None:
            return None
        return self._to_str(raw)

    async def send_heartbeat(self, task_id: str) -> None:
        """Store the current UNIX time as the heartbeat of *task_id*."""
        key = f"{self.HEARTBEAT_PREFIX}{task_id}"
        await self._redis.set(key, str(time.time()), ex=self.HEARTBEAT_TTL)

    async def get_heartbeat(self, task_id: str) -> float | None:
        """Return the stored heartbeat timestamp, or ``None`` if absent."""
        key = f"{self.HEARTBEAT_PREFIX}{task_id}"
        raw = await self._redis.get(key)
        if raw is None:
            return None
        return float(raw)

    async def get_heartbeats(self, task_ids: list[str]) -> dict[str, float | None]:
        """Batch fetch via ``MGET`` for efficiency."""
        if not task_ids:
            return {}
        keys = [f"{self.HEARTBEAT_PREFIX}{tid}" for tid in task_ids]
        values = await self._redis.mget(keys)
        return {
            tid: float(v) if v is not None else None
            for tid, v in zip(task_ids, values)
        }

    async def cancel_task(self, task_id: str) -> None:
        """Set the cancellation flag for *task_id*."""
        key = f"{self.CANCEL_PREFIX}{task_id}"
        await self._redis.set(key, "1", ex=self.CANCEL_TTL)

    async def is_cancelled(self, task_id: str) -> bool:
        """Return ``True`` if a cancellation flag exists for *task_id*."""
        key = f"{self.CANCEL_PREFIX}{task_id}"
        return int(await self._redis.exists(key)) > 0

    async def send_progress(self, task_id: str, progress_json: str) -> None:
        """Store *progress_json* as the latest progress of *task_id*."""
        key = f"{self.PROGRESS_PREFIX}{task_id}"
        await self._redis.set(key, progress_json, ex=self.PROGRESS_TTL)

    async def get_progress(self, task_id: str) -> str | None:
        """Return the latest stored progress JSON, or ``None`` if absent."""
        key = f"{self.PROGRESS_PREFIX}{task_id}"
        raw = await self._redis.get(key)
        if raw is None:
            return None
        return self._to_str(raw)

    async def cancel_schedule(self, schedule_id: str) -> None:
        """Set the cancellation flag for *schedule_id*."""
        key = f"{self.SCHEDULE_PREFIX}{schedule_id}"
        await self._redis.set(key, "1", ex=self.SCHEDULE_TTL)

    async def is_schedule_cancelled(self, schedule_id: str) -> bool:
        """Return ``True`` if a cancellation flag exists for *schedule_id*."""
        key = f"{self.SCHEDULE_PREFIX}{schedule_id}"
        return int(await self._redis.exists(key)) > 0

    async def check_throttle(self, function_name: str) -> bool:
        """Return ``True`` if *function_name* is currently allowed to run."""
        key = f"{self.THROTTLE_PREFIX}{function_name}"
        return int(await self._redis.exists(key)) == 0

    async def record_throttle(
        self, function_name: str, throttle_seconds: float,
    ) -> None:
        """Record that *function_name* is throttled for *throttle_seconds*.

        The window is stored with millisecond precision (``PX``) so
        fractional durations are neither truncated nor rounded up to a
        whole second; the minimum expiry is 1 ms.
        """
        key = f"{self.THROTTLE_PREFIX}{function_name}"
        ttl_ms = max(1, int(throttle_seconds * 1000))
        await self._redis.set(key, "1", px=ttl_ms)

    async def notify_result(self, task_id: str) -> None:
        """Publish task_id on the Pub/Sub notification channel."""
        await self._redis.publish(self.NOTIFY_CHANNEL, task_id)

    async def subscribe_results(self) -> AsyncIterator[str]:
        """Subscribe to the Pub/Sub channel and yield task IDs on result arrival."""
        pubsub = self._redis.pubsub()
        try:
            await pubsub.subscribe(self.NOTIFY_CHANNEL)
            while True:
                msg = await pubsub.get_message(
                    ignore_subscribe_messages=True, timeout=1.0,
                )
                if msg is not None and msg["type"] == "message":
                    yield self._to_str(msg["data"])
                elif msg is None:
                    await asyncio.sleep(0.01)
        finally:
            # Always release the pubsub connection, even if the
            # unsubscribe itself fails.
            try:
                await pubsub.unsubscribe(self.NOTIFY_CHANNEL)
            finally:
                await pubsub.aclose()

    async def close(self) -> None:
        """Close the underlying Redis client."""
        await self._redis.aclose()
|