offwork 0.4.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (42) hide show
  1. offwork/__init__.py +167 -0
  2. offwork/__main__.py +770 -0
  3. offwork/_venv.py +174 -0
  4. offwork/core/__init__.py +15 -0
  5. offwork/core/errors.py +83 -0
  6. offwork/core/models.py +174 -0
  7. offwork/core/pairing.py +389 -0
  8. offwork/core/progress.py +91 -0
  9. offwork/core/signing.py +91 -0
  10. offwork/core/task.py +520 -0
  11. offwork/core/token.py +184 -0
  12. offwork/core/version.py +10 -0
  13. offwork/graph/__init__.py +5 -0
  14. offwork/graph/analyzer.py +637 -0
  15. offwork/graph/decorator.py +87 -0
  16. offwork/graph/graph.py +995 -0
  17. offwork/graph/store.py +500 -0
  18. offwork/graph/tracing.py +429 -0
  19. offwork/py.typed +0 -0
  20. offwork/typing.py +48 -0
  21. offwork/worker/__init__.py +18 -0
  22. offwork/worker/backends/__init__.py +3 -0
  23. offwork/worker/backends/base.py +149 -0
  24. offwork/worker/backends/http.py +237 -0
  25. offwork/worker/backends/local.py +452 -0
  26. offwork/worker/backends/rabbitmq.py +410 -0
  27. offwork/worker/backends/redis.py +175 -0
  28. offwork/worker/deps.py +365 -0
  29. offwork/worker/remote.py +793 -0
  30. offwork/worker/result.py +276 -0
  31. offwork/worker/sandbox/Dockerfile +24 -0
  32. offwork/worker/sandbox/__init__.py +18 -0
  33. offwork/worker/sandbox/_protocol.py +50 -0
  34. offwork/worker/sandbox/docker.py +438 -0
  35. offwork/worker/sandbox/guest_agent.py +622 -0
  36. offwork/worker/schedule.py +26 -0
  37. offwork/worker/worker.py +263 -0
  38. offwork-0.4.0.dist-info/METADATA +143 -0
  39. offwork-0.4.0.dist-info/RECORD +42 -0
  40. offwork-0.4.0.dist-info/WHEEL +4 -0
  41. offwork-0.4.0.dist-info/entry_points.txt +3 -0
  42. offwork-0.4.0.dist-info/licenses/LICENSE +661 -0
@@ -0,0 +1,793 @@
1
+ """Remote execution orchestration: connect, serve, and submit tasks."""
2
+
3
+ import os
4
+ import sys
5
+ import json
6
+ import time
7
+ import uuid
8
+ import atexit
9
+ import signal
10
+ import asyncio
11
+ import inspect
12
+ import logging
13
+ import contextlib
14
+ from typing import TYPE_CHECKING, Any
15
+ from collections.abc import Callable, Awaitable
16
+
17
+ from offwork.core.task import Task
18
+ from offwork.core.token import resolve_signing_key
19
+ from offwork.core.version import _VERSION
20
+ from offwork.core.progress import _progress_callback
21
+ from offwork.worker.result import Result, ResultEnvelope
22
+ from offwork.worker.worker import Worker
23
+ from offwork.worker.schedule import ScheduleHandle
24
+ from offwork.worker.backends.base import Backend
25
+
26
+ if TYPE_CHECKING:
27
+ from offwork.worker.sandbox import DockerSandbox
28
+
29
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)

# Process-wide backend: assigned by connect(), reset to None by
# disconnect() and by the atexit hook _sync_disconnect().
_active_backend: Backend | None = None
# True once the atexit hook has been registered, so that repeated
# connect() calls register it only once.
_atexit_registered = False

# Name of the environment variable consulted when no URL is passed.
_ENV_VAR = "OFFWORK_BACKEND"
35
+
36
+
37
+ def _resolve_url(url: str | None) -> str:
38
+ """Return *url* if given, otherwise read from the environment variable."""
39
+ if url is not None:
40
+ return url
41
+ env_url = os.environ.get(_ENV_VAR)
42
+ if env_url:
43
+ return env_url
44
+ raise ValueError(
45
+ "No backend URL provided. Pass a URL or set the "
46
+ f"{_ENV_VAR} environment variable."
47
+ )
48
+
49
+
50
+ def _create_backend(url: str, **kwargs: Any) -> Backend:
51
+ """Create a backend instance from a URL."""
52
+ scheme = url.split("://", 1)[0].lower()
53
+ if scheme in ("redis", "rediss"):
54
+ from offwork.worker.backends.redis import RedisBackend
55
+
56
+ return RedisBackend(url, **kwargs)
57
+ if scheme == "local":
58
+ from offwork.worker.backends.local import LocalBackend
59
+
60
+ return LocalBackend(url, **kwargs)
61
+ if scheme in ("amqp", "amqps"):
62
+ from offwork.worker.backends.rabbitmq import RabbitMQBackend
63
+
64
+ return RabbitMQBackend(url, **kwargs)
65
+ if scheme in ("http", "https"):
66
+ from offwork.worker.backends.http import HttpBackend
67
+
68
+ return HttpBackend(url)
69
+ raise ValueError(
70
+ f"Unknown backend scheme: {scheme!r}. "
71
+ f"Supported: redis://, rediss://, local://, amqp://, amqps://, http://, https://"
72
+ )
73
+
74
+
75
def connect(url: str | None = None, **kwargs: Any) -> Backend:
    """Create a backend and install it as the process-wide transport.

    The first call also registers an ``atexit`` hook that closes the
    active backend when the interpreter exits.

    Parameters
    ----------
    url
        Backend URL. Supported schemes:

        - ``redis://`` / ``rediss://`` -- :class:`RedisBackend`
        - ``local://`` -- :class:`LocalBackend` (same-machine IPC)
        - ``amqp://`` / ``amqps://`` -- :class:`RabbitMQBackend`
        - ``http://`` / ``https://`` -- :class:`HttpBackend`

        When *None*, the ``OFFWORK_BACKEND`` environment variable is used.

    **kwargs
        Passed to the backend constructor.

    Returns
    -------
    Backend
        The backend instance that is now the active one.
    """
    global _active_backend, _atexit_registered

    target = _resolve_url(url)
    backend = _create_backend(target, **kwargs)
    _active_backend = backend

    # Register the exit hook once, no matter how often connect() is called.
    if not _atexit_registered:
        atexit.register(_sync_disconnect)
        _atexit_registered = True

    logger.debug("Connected to backend: %s", target)
    return backend
106
+
107
+
108
async def disconnect() -> None:
    """Close and clear the global backend.

    Does nothing when no backend is active. The module-level reference is
    cleared even when ``close()`` raises, so a backend that failed to shut
    down is not left installed as the active one (and is not closed a
    second time by the exit hook).

    Raises
    ------
    Exception
        Whatever the backend's ``close()`` raises is propagated after the
        global reference has been cleared.
    """
    global _active_backend
    backend = _active_backend
    if backend is None:
        return
    try:
        await backend.close()
    finally:
        # Only clear the slot if it still holds the backend we closed; a
        # concurrent connect() may have installed a new one in the meantime.
        if _active_backend is backend:
            _active_backend = None
    logger.info("Disconnected from backend")
115
+
116
+
117
def _sync_disconnect() -> None:
    """Synchronous atexit handler for disconnect.

    Closes the active backend on a fresh event loop via ``asyncio.run``.
    This is best-effort: any failure while closing is logged at debug
    level instead of escaping from the exit hook, and the global reference
    is cleared in every case.
    """
    global _active_backend
    backend = _active_backend
    if backend is None:
        return
    # Clear first so the slot is empty whatever happens while closing.
    _active_backend = None
    try:
        asyncio.run(backend.close())
    except Exception:
        # Covers RuntimeError (e.g. event loop already closed) as well as
        # transport errors raised by the backend's close().
        logger.debug("Backend close failed during interpreter exit", exc_info=True)
126
+
127
+
128
def get_backend() -> Backend:
    """Return the active backend, connecting from the environment if needed.

    When :func:`connect` has not been called yet, the ``OFFWORK_BACKEND``
    environment variable is read and, if non-empty, used to connect.

    Raises
    ------
    RuntimeError
        If no backend is active and the environment variable is unset or
        empty.
    """
    global _active_backend
    if _active_backend is not None:
        return _active_backend

    fallback_url = os.environ.get(_ENV_VAR)
    if not fallback_url:
        raise RuntimeError(
            "No backend connected. Call offwork.connect('redis://...') or offwork.connect('https://...') "
            f"or set the {_ENV_VAR} environment variable."
        )
    connect(fallback_url)
    return _active_backend  # type: ignore[return-value]
146
+
147
+
148
async def submit_remote(
    func: Callable[..., object],
    wrapper: Callable[..., object],
    *args: Any,
    _backend: str | Backend | None = None,
    _signing_key: bytes | None = None,
    **kwargs: Any,
) -> Result:
    """Serialize a traced function call and submit it to a backend.

    Called internally by ``traced_func.run(...)``.

    Parameters
    ----------
    func
        The function being submitted; unwrapped to derive the task's
        ``module.qualname`` function name.
    wrapper
        The object passed to ``Graph.default().serialize``; its
        ``__offwork_options__`` attribute (if present) supplies timeout,
        retry and throttle settings.
    *args, **kwargs
        Call arguments stored on the task.
    _backend
        A backend URL (a new backend is created from it), a backend
        instance, or ``None`` to use the global backend.
    _signing_key
        When provided, the task JSON is HMAC-signed so the worker can
        verify the origin. When ``None``, the key is looked up with
        ``resolve_signing_key("client")``.

    Returns
    -------
    Result
        Handle bound to the new task id and the backend used.
    """
    from offwork.graph.graph import Graph  # circular

    if isinstance(_backend, Backend):
        transport = _backend
    elif isinstance(_backend, str):
        transport = _create_backend(_backend)
    else:
        transport = get_backend()

    # Fall back to the client's stored key when none was passed in.
    if _signing_key is None:
        _signing_key = resolve_signing_key("client")

    target = inspect.unwrap(func)
    function_name = f"{target.__module__}.{target.__qualname__}"
    logger.debug("Serializing graph for %s", function_name)
    serialized_graph = Graph.default().serialize(wrapper)

    options = getattr(wrapper, "__offwork_options__", {})
    policy = {
        "timeout": options.get("timeout"),
        "retries": options.get("retries", 0),
        "retry_delay": options.get("retry_delay", 1.0),
        "throttle": options.get("throttle"),
    }
    task = Task(
        graph_json=serialized_graph,
        function_name=function_name,
        args=args,
        kwargs=kwargs,
        **policy,
    )

    logger.debug("Submitting task %s → %s", task.task_id[:8], function_name)
    payload = task.to_json(signing_key=_signing_key)
    await transport.submit(payload)
    logger.info("Submitted task %s for %s", task.task_id, function_name)
    return Result(task.task_id, transport)
200
+
201
+
202
async def submit_remote_scheduled(
    func: Callable[..., object],
    wrapper: Callable[..., object],
    *args: Any,
    _backend: str | Backend | None = None,
    _signing_key: bytes | None = None,
    _scheduled_at: float | None = None,
    **kwargs: Any,
) -> Result:
    """Submit a task that carries a ``scheduled_at`` timestamp.

    Parameters
    ----------
    func
        The function being submitted; unwrapped to derive the task's
        ``module.qualname`` function name.
    wrapper
        The object passed to ``Graph.default().serialize``; its
        ``__offwork_options__`` attribute (if present) supplies timeout,
        retry and throttle settings.
    *args, **kwargs
        Call arguments stored on the task.
    _backend
        A backend URL, a backend instance, or ``None`` for the global one.
    _signing_key
        Signing key for the task JSON; when ``None`` it is looked up with
        ``resolve_signing_key("client")``.
    _scheduled_at
        Value stored as the task's ``scheduled_at`` (the worker compares
        it against ``time.time()``).

    Returns
    -------
    Result
        Handle bound to the new task id and the backend used.
    """
    from offwork.graph.graph import Graph  # circular

    if isinstance(_backend, Backend):
        transport = _backend
    elif isinstance(_backend, str):
        transport = _create_backend(_backend)
    else:
        transport = get_backend()

    if _signing_key is None:
        _signing_key = resolve_signing_key("client")

    target = inspect.unwrap(func)
    function_name = f"{target.__module__}.{target.__qualname__}"
    logger.debug("Serializing graph for %s", function_name)
    serialized_graph = Graph.default().serialize(wrapper)

    options = getattr(wrapper, "__offwork_options__", {})
    policy = {
        "timeout": options.get("timeout"),
        "retries": options.get("retries", 0),
        "retry_delay": options.get("retry_delay", 1.0),
        "throttle": options.get("throttle"),
    }
    task = Task(
        graph_json=serialized_graph,
        function_name=function_name,
        args=args,
        kwargs=kwargs,
        scheduled_at=_scheduled_at,
        **policy,
    )

    # Log 0 rather than None when no timestamp was given.
    when = _scheduled_at or 0
    logger.debug(
        "Submitting scheduled task %s → %s (at %.3f)",
        task.task_id[:8], function_name, when,
    )
    payload = task.to_json(signing_key=_signing_key)
    await transport.submit(payload)
    logger.info(
        "Submitted scheduled task %s for %s (at %.0f)",
        task.task_id, function_name, when,
    )
    return Result(task.task_id, transport)
252
+
253
+
254
async def submit_recurring(
    func: Callable[..., object],
    wrapper: Callable[..., object],
    *args: Any,
    _backend: str | Backend | None = None,
    _signing_key: bytes | None = None,
    _interval: float = 0,
    _start_at: float | None = None,
    **kwargs: Any,
) -> ScheduleHandle:
    """Submit a recurring task and return a :class:`ScheduleHandle`.

    Parameters
    ----------
    func
        The function being submitted; unwrapped to derive the task's
        ``module.qualname`` function name.
    wrapper
        The object passed to ``Graph.default().serialize``; its
        ``__offwork_options__`` attribute (if present) supplies timeout,
        retry and throttle settings.
    *args, **kwargs
        Call arguments stored on the task.
    _backend
        A backend URL, a backend instance, or ``None`` for the global one.
    _signing_key
        Signing key for the task JSON; when ``None`` it is looked up with
        ``resolve_signing_key("client")``.
    _interval
        Seconds between runs, stored as the task's ``recur_interval``.
        Must be greater than zero.
    _start_at
        Timestamp for the first run; ``None`` means "now".

    Returns
    -------
    ScheduleHandle
        Handle bound to the generated schedule id and the backend used.

    Raises
    ------
    ValueError
        If *_interval* is not a positive number. The worker re-enqueues a
        recurring task at ``now + recur_interval`` after every run, so a
        zero or negative interval would make it run back-to-back forever.
    """
    # ``not x > 0`` also rejects NaN.
    if not _interval > 0:
        raise ValueError(
            f"Recurring tasks need a positive interval, got {_interval!r}"
        )

    from offwork.graph.graph import Graph  # circular

    if isinstance(_backend, str):
        backend = _create_backend(_backend)
    elif isinstance(_backend, Backend):
        backend = _backend
    else:
        backend = get_backend()

    if _signing_key is None:
        _signing_key = resolve_signing_key("client")

    unwrapped = inspect.unwrap(func)
    function_name = f"{unwrapped.__module__}.{unwrapped.__qualname__}"
    logger.debug("Serializing graph for %s", function_name)
    graph_json = Graph.default().serialize(wrapper)

    # Short id shared by every run of this schedule.
    schedule_id = uuid.uuid4().hex[:12]
    # Compare against None so an explicit 0.0 start time is honoured.
    scheduled_at = time.time() if _start_at is None else _start_at

    opts = getattr(wrapper, "__offwork_options__", {})
    task = Task(
        graph_json=graph_json,
        function_name=function_name,
        args=args,
        kwargs=kwargs,
        timeout=opts.get("timeout"),
        retries=opts.get("retries", 0),
        retry_delay=opts.get("retry_delay", 1.0),
        throttle=opts.get("throttle"),
        scheduled_at=scheduled_at,
        recur_interval=_interval,
        schedule_id=schedule_id,
    )

    logger.debug(
        "Submitting recurring task %s → %s (every %.1fs, schedule=%s)",
        task.task_id[:8], function_name, _interval, schedule_id,
    )
    await backend.submit(task.to_json(signing_key=_signing_key))
    logger.info(
        "Submitted recurring task %s for %s (every %.1fs, schedule=%s)",
        task.task_id, function_name, _interval, schedule_id,
    )
    return ScheduleHandle(schedule_id, backend)
310
+
311
+
312
def _build_detail_tags(worker: Worker) -> str:
    """Build a comma-separated detail string from the last build info.

    The first tag is ``"cached"`` when ``worker.last_build_info()``
    reports a cache hit and ``"build"`` otherwise (including when there is
    no build info at all). When packages were installed, a second tag
    ``"pip <pkg> <pkg> ..."`` is appended.
    """
    build_info = worker.last_build_info()
    if build_info is not None and build_info.cache_hit:
        parts = ["cached"]
    else:
        parts = ["build"]
    # Installed packages are listed space-separated after a "pip " prefix.
    if build_info is not None and build_info.installed_packages:
        parts.append("pip " + " ".join(build_info.installed_packages))
    return ", ".join(parts)
322
+
323
+
324
_HEARTBEAT_INTERVAL = 1.0


async def _heartbeat_loop(
    backend: Backend,
    task_id: str,
    cancel_event: asyncio.Event,
    exec_task: asyncio.Task[Any] | None = None,
) -> None:
    """Emit heartbeats for *task_id* until *cancel_event* is set.

    Each round sends one heartbeat and, when *exec_task* is given, asks
    the backend whether the task was cancelled. On a positive answer the
    execution task is cancelled via :meth:`asyncio.Task.cancel` (raising
    :class:`CancelledError` at the next ``await`` in async user functions)
    and the loop ends. Backend errors in either step are logged at debug
    level and do not stop the loop. Between rounds the loop waits up to
    ``_HEARTBEAT_INTERVAL`` seconds, waking early if *cancel_event* is set.
    """
    while True:
        if cancel_event.is_set():
            return

        try:
            await backend.send_heartbeat(task_id)
        except Exception:
            logger.debug("Heartbeat send failed for task %s", task_id, exc_info=True)

        if exec_task is not None:
            try:
                cancellation_requested = await backend.is_cancelled(task_id)
                if cancellation_requested:
                    exec_task.cancel()
                    return
            except Exception:
                logger.debug("Cancellation check failed for task %s", task_id, exc_info=True)

        # A timeout only means the interval elapsed without a stop request.
        with contextlib.suppress(asyncio.TimeoutError):
            await asyncio.wait_for(cancel_event.wait(), timeout=_HEARTBEAT_INTERVAL)
356
+
357
+
358
_PROGRESS_MIN_INTERVAL = 0.05  # 50ms rate limit


def _make_progress_callback(
    backend: Backend,
    task_id: str,
    loop: asyncio.AbstractEventLoop,
) -> tuple[
    Callable[[float, float | None, str | None], None],
    Callable[[], Awaitable[None]],
]:
    """Create a rate-limited progress callback.

    Returns ``(callback, flush)``. The callback stores the latest
    progress locally and only sends to the backend when at least 50 ms
    have elapsed since the last send. Call ``await flush()`` after
    execution to guarantee the final state is delivered.

    Every fire-and-forget send is tracked until it finishes. This keeps a
    strong reference to each task (the event loop itself only holds weak
    ones) and lets ``flush()`` wait for *all* outstanding sends, so an
    older update cannot reach the backend after the final one.

    Parameters
    ----------
    backend
        Backend whose ``send_progress`` receives the JSON updates.
    task_id
        Id of the task the progress belongs to.
    loop
        Event loop used to schedule sends when the callback is invoked
        from a thread that has no running loop.
    """
    state: dict[str, Any] = {
        "latest": None,      # latest progress dict (always kept)
        "last_sent": 0.0,    # monotonic time of last send
        "flushed": False,    # set by flush() to block late sends
    }
    # Sends that have been started but have not completed yet.
    in_flight: set[asyncio.Task[None]] = set()

    async def _send(data_json: str) -> None:
        # Progress is best-effort: a failed send must not fail the task.
        try:
            await backend.send_progress(task_id, data_json)
        except Exception:
            logger.debug("Progress send failed for task %s", task_id, exc_info=True)

    def _do_send(data_json: str) -> None:
        # May run later via call_soon_threadsafe, i.e. after flush().
        if state["flushed"]:
            return
        send_task = asyncio.create_task(_send(data_json))
        in_flight.add(send_task)
        send_task.add_done_callback(in_flight.discard)

    def _on_progress(
        current: float,
        total: float | None = None,
        message: str | None = None,
    ) -> None:
        if state["flushed"]:
            return
        d: dict[str, Any] = {"current": current}
        if total is not None:
            d["total"] = total
        if message is not None:
            d["message"] = message
        state["latest"] = d

        now = time.monotonic()
        if now - state["last_sent"] >= _PROGRESS_MIN_INTERVAL:
            data_json = json.dumps(d, separators=(",", ":"))
            state["last_sent"] = now
            try:
                # Raises RuntimeError when this thread has no running loop.
                asyncio.get_running_loop()
                _do_send(data_json)
            except RuntimeError:
                loop.call_soon_threadsafe(_do_send, data_json)

    async def _flush() -> None:
        state["flushed"] = True
        # Do not cancel an in-flight backend RPC on the shared channel.
        # Wait for every outstanding send, then send the authoritative
        # final state. asyncio.wait() does not cancel the tasks it waits
        # on if this coroutine is itself cancelled.
        if in_flight:
            with contextlib.suppress(asyncio.CancelledError):
                await asyncio.wait(set(in_flight))
        # Always send the authoritative final state
        if state["latest"] is not None:
            data_json = json.dumps(state["latest"], separators=(",", ":"))
            await _send(data_json)

    return _on_progress, _flush
432
+
433
+
434
+ def _log_task_result(
435
+ task: Task,
436
+ envelope: ResultEnvelope,
437
+ elapsed_ms: float,
438
+ worker: Worker,
439
+ ) -> None:
440
+ """Log the outcome of a completed task."""
441
+ short_id = task.task_id[:8]
442
+ details = _build_detail_tags(worker)
443
+ if envelope.status == "ok":
444
+ logger.info(
445
+ "\u2713 %-40s %6.0fms %s %s",
446
+ task.function_name, elapsed_ms, short_id, details,
447
+ )
448
+ elif envelope.status == "cancelled":
449
+ logger.info(
450
+ "\u2718 %-40s %s cancelled",
451
+ task.function_name, short_id,
452
+ )
453
+ else:
454
+ error_msg = f" {envelope.error_type}: {envelope.error_message}"
455
+ logger.warning(
456
+ "\u2717 %-40s %6.0fms %s %s%s",
457
+ task.function_name, elapsed_ms, short_id, details, error_msg,
458
+ )
459
+
460
+
461
async def _handle_task(
    worker: Worker,
    backend: Backend,
    task_json: str,
    signing_key: bytes | None = None,
) -> None:
    """Process a single task: deserialize, execute with policy, send result.

    Steps, in order:

    1. Parse (and, with *signing_key*, verify) the task JSON. On failure
       an error result is sent if a task id can be extracted.
    2. Sleep until ``scheduled_at`` when it lies in the future.
    3. Return without executing if the backend reports the task cancelled.
    4. If the task has a throttle, ask the backend whether it may run.
    5. Run ``worker.run_with_policy`` as an asyncio task next to a
       heartbeat loop that can cancel it.
    6. Flush progress, then send the result and notify the client.
    7. Record the throttle cooldown after a successful run.
    8. Re-enqueue the task if it is recurring and its schedule has not
       been cancelled.

    Parameters
    ----------
    worker
        Worker that executes the task.
    backend
        Backend used for results, heartbeats, throttling and re-enqueueing.
    task_json
        Serialized task as received from the backend.
    signing_key
        When provided, the task must carry a valid HMAC-SHA256 signature.
        Unsigned or mis-signed tasks are rejected with an error result.
    """
    try:
        task = Task.from_json(task_json, signing_key=signing_key)
    except Exception as exc:
        # If we can extract a task_id, send an error envelope so the
        # client gets feedback instead of hanging forever.
        logger.warning("Task rejected: %s", exc)
        try:
            data = json.loads(task_json)
            task_id = data.get("id", "unknown")
        except Exception:
            # Payload is not even valid JSON: there is nobody to notify.
            task_id = "unknown"
        if task_id != "unknown":
            envelope = ResultEnvelope.failure(task_id, exc)
            await backend.send_result(task_id, envelope.to_json())
            await backend.notify_result(task_id)
        return

    logger.debug("Received task %s: %s", task.task_id, task.function_name)

    # Wait for scheduled time
    if task.scheduled_at is not None:
        delay = task.scheduled_at - time.time()
        if delay > 0:
            logger.debug("Task %s scheduled in %.1fs", task.task_id, delay)
            await asyncio.sleep(delay)

    # Any failure in the backend checks below must still surface to the
    # client, otherwise it would hang forever polling for a result.
    try:
        cancelled = await backend.is_cancelled(task.task_id)
    except Exception as exc:
        logger.exception("is_cancelled failed for task %s", task.task_id)
        envelope = ResultEnvelope.failure(task.task_id, exc)
        await backend.send_result(task.task_id, envelope.to_json())
        await backend.notify_result(task.task_id)
        return

    if cancelled:
        # Cancelled before execution started: the envelope is only used
        # for the log line; no result is sent from this branch.
        envelope = ResultEnvelope.cancelled(task.task_id)
        _log_task_result(task, envelope, 0, worker)
        return

    # Check throttle
    if task.throttle is not None:
        try:
            allowed = await backend.check_throttle(task.function_name)
        except Exception as exc:
            logger.exception(
                "check_throttle failed for task %s (%s)",
                task.task_id, task.function_name,
            )
            envelope = ResultEnvelope.failure(task.task_id, exc)
            await backend.send_result(task.task_id, envelope.to_json())
            await backend.notify_result(task.task_id)
            return

        if not allowed:
            # Report the throttled outcome to the client and stop here.
            envelope = ResultEnvelope.throttled(task.task_id)
            await backend.send_result(task.task_id, envelope.to_json())
            await backend.notify_result(task.task_id)
            logger.info(
                "%-40s %s throttled",
                task.function_name, task.task_id[:8],
            )
            return

    # Set up rate-limited progress callback
    loop = asyncio.get_running_loop()
    progress_cb, flush = _make_progress_callback(backend, task.task_id, loop)
    # Install the callback in the context variable; the token restores the
    # previous value in the ``finally`` block below.
    token = _progress_callback.set(progress_cb)

    # Run execution as a task so the heartbeat loop can cancel it
    exec_task: asyncio.Task[Any] = asyncio.create_task(worker.run_with_policy(task))

    # Set once execution has finished, to make the heartbeat loop exit.
    cancel_event = asyncio.Event()
    hb_task = asyncio.create_task(
        _heartbeat_loop(backend, task.task_id, cancel_event, exec_task),
    )

    t0 = time.monotonic()
    try:
        result = await exec_task
        envelope = ResultEnvelope.success(task.task_id, result)
    except asyncio.CancelledError:
        # Raised when exec_task was cancelled, e.g. by the heartbeat loop
        # after the backend reported the task as cancelled.
        envelope = ResultEnvelope.cancelled(task.task_id)
    except Exception as exc:
        logger.debug("Task %s failed", task.task_id, exc_info=True)
        envelope = ResultEnvelope.failure(task.task_id, exc)
    finally:
        _progress_callback.reset(token)
        cancel_event.set()
        # Do not hb_task.cancel() — cancelling can interrupt an
        # in-flight AMQP RPC (e.g. Queue.Declare) on the shared
        # channel, causing it to close and preventing send_result
        # from delivering the result.  The cancel_event already
        # signals the loop to exit promptly.
        with contextlib.suppress(asyncio.CancelledError):
            await hb_task

    elapsed_ms = (time.monotonic() - t0) * 1000

    # Flush any pending progress, then send the result unconditionally.
    # If the client cancelled mid-execution, it already stored a cancelled
    # result envelope (via Result.cancel) which the client reads first.
    await flush()
    try:
        result_json = envelope.to_json()
    except Exception as exc:
        # Result serialization failed (e.g. unsupported return type).
        # Surface the failure as an error envelope rather than letting
        # the task hang forever from the client's point of view.
        logger.exception(
            "Failed to serialize result for task %s", task.task_id,
        )
        envelope = ResultEnvelope.failure(task.task_id, exc)
        result_json = envelope.to_json()
    await backend.send_result(task.task_id, result_json)
    await backend.notify_result(task.task_id)

    _log_task_result(task, envelope, elapsed_ms, worker)

    # Record throttle cooldown after successful execution
    if task.throttle is not None and envelope.status == "ok":
        await backend.record_throttle(task.function_name, task.throttle)

    # Re-enqueue recurring task
    if task.recur_interval is not None and task.schedule_id is not None:
        if not await backend.is_schedule_cancelled(task.schedule_id):
            # Same payload and policy; the next run is timed from now, so
            # the period is measured from the end of this run.
            next_task = Task(
                graph_json=task.graph_json,
                function_name=task.function_name,
                args=task.args,
                kwargs=task.kwargs,
                timeout=task.timeout,
                retries=task.retries,
                retry_delay=task.retry_delay,
                throttle=task.throttle,
                scheduled_at=time.time() + task.recur_interval,
                recur_interval=task.recur_interval,
                schedule_id=task.schedule_id,
            )
            # Signed with the same key this worker verifies tasks with.
            await backend.submit(next_task.to_json(signing_key=signing_key))
            logger.debug(
                "Re-enqueued recurring task %s (schedule=%s, next in %.1fs)",
                next_task.task_id, task.schedule_id, task.recur_interval,
            )
621
+
622
+
623
async def _worker_loop(
    worker: Worker,
    backend: Backend,
    concurrency: int,
    signing_key: bytes | None = None,
) -> None:
    """Consume tasks from *backend* and dispatch to *worker*.

    Supports graceful shutdown: on the first SIGINT/SIGTERM, stops
    accepting new tasks and waits for in-progress tasks to complete.
    On the second signal, cancels all in-progress tasks immediately.

    The loop also shuts down when the backend listener stops on its own.
    If ``backend.listen()`` raises, in-flight tasks are drained and the
    listener's exception is then re-raised; if the iterator simply ends,
    the loop drains and returns. Without this the worker would wait for a
    shutdown signal forever while no longer receiving tasks.

    Parameters
    ----------
    worker
        Worker that executes the tasks.
    backend
        Backend to consume tasks from.
    concurrency
        Maximum number of tasks handled at the same time; must be >= 1.
    signing_key
        When provided, every incoming task must carry a valid signature.

    Raises
    ------
    ValueError
        If *concurrency* is smaller than 1 (a zero-slot semaphore would
        never let any task run).
    Exception
        Whatever the backend listener raised, after in-flight tasks have
        been drained.
    """
    if concurrency < 1:
        raise ValueError(f"concurrency must be >= 1, got {concurrency!r}")

    shutdown = asyncio.Event()
    pending: set[asyncio.Task[None]] = set()
    sem = asyncio.Semaphore(concurrency)

    _got_first_signal = False

    def _on_shutdown_signal() -> None:
        nonlocal _got_first_signal
        if not _got_first_signal:
            # First signal: request a graceful stop.
            _got_first_signal = True
            shutdown.set()
        else:
            # Second signal: stop waiting and cancel whatever is running.
            logger.warning("Forced shutdown — cancelling %d task(s).", len(pending))
            for t in pending:
                t.cancel()

    # Install signal handlers (not available on Windows)
    loop = asyncio.get_running_loop()
    _signals_installed = False
    for sig in (signal.SIGINT, signal.SIGTERM):
        try:
            loop.add_signal_handler(sig, _on_shutdown_signal)
            _signals_installed = True
        except (NotImplementedError, RuntimeError):
            pass

    async def bounded_handle(task_json: str) -> None:
        # The semaphore caps how many tasks are processed concurrently.
        async with sem:
            await _handle_task(worker, backend, task_json, signing_key=signing_key)

    async def _listen() -> None:
        async for task_json in backend.listen():
            if shutdown.is_set():
                return
            task = asyncio.create_task(bounded_handle(task_json))
            # Keep a reference until the task finishes.
            pending.add(task)
            task.add_done_callback(pending.discard)

    listen_task = asyncio.create_task(_listen())
    shutdown_waiter = asyncio.create_task(shutdown.wait())

    # Wake up on a shutdown request OR when the listener stops by itself.
    try:
        await asyncio.wait(
            {listen_task, shutdown_waiter},
            return_when=asyncio.FIRST_COMPLETED,
        )
    except (KeyboardInterrupt, asyncio.CancelledError):
        pass
    listener_stopped_first = listen_task.done() and not shutdown.is_set()
    shutdown.set()

    shutdown_waiter.cancel()
    with contextlib.suppress(asyncio.CancelledError):
        await shutdown_waiter

    # Stop accepting new tasks
    listen_error: Exception | None = None
    if not listen_task.done():
        listen_task.cancel()
    try:
        await listen_task
    except asyncio.CancelledError:
        pass
    except Exception as exc:
        # Remember the failure; re-raise only after draining and cleanup.
        listen_error = exc

    if listen_error is not None:
        logger.error(
            "Backend listener failed; shutting down worker: %s", listen_error,
        )
    elif listener_stopped_first:
        logger.warning("Backend listener ended; shutting down worker.")

    # Wait for in-flight tasks to complete
    if pending:
        logger.info(
            "Graceful shutdown: waiting for %d task(s) to complete... "
            "(Ctrl+C to force quit)",
            len(pending),
        )
        try:
            await asyncio.gather(*pending, return_exceptions=True)
        except (KeyboardInterrupt, asyncio.CancelledError):
            for t in pending:
                t.cancel()
            await asyncio.gather(*pending, return_exceptions=True)

    # Clean up signal handlers
    if _signals_installed:
        for sig in (signal.SIGINT, signal.SIGTERM):
            with contextlib.suppress(NotImplementedError, RuntimeError):
                loop.remove_signal_handler(sig)

    if listen_error is not None:
        raise listen_error
709
+
710
+
711
async def serve(
    url: str | None = None,
    *,
    concurrency: int = 1,
    auto_install: bool = True,
    import_to_package: dict[str, str] | None = None,
    sandbox: "DockerSandbox | bool | None" = None,
    require_signing: bool = False,
) -> None:
    """Start a worker loop that pops tasks from the backend and executes them.

    The backend is always disconnected on the way out, including when the
    worker cannot be constructed, when the sandbox fails to start, or when
    stopping the sandbox fails.

    Parameters
    ----------
    url
        Backend URL (e.g. ``redis://localhost:6379``).
        When *None*, the ``OFFWORK_BACKEND`` environment variable is used.
    concurrency
        Number of concurrent tasks (default: 1). Must be >= 1.
    auto_install
        Automatically install missing third-party dependencies via pip.
    import_to_package
        Extra import-name to pip-package-name mappings.
    sandbox
        ``True`` or a :class:`~offwork.worker.sandbox.DockerSandbox`
        instance to execute tasks inside a Docker container.
    require_signing
        When ``True``, only execute tasks that carry a valid HMAC
        signature from a paired client.  The shared key is loaded from
        ``~/.offwork/worker.key`` (written by ``offwork pair``).

    Raises
    ------
    ValueError
        If *concurrency* is smaller than 1.
    SystemExit
        With exit code 1 when signing is required but no key is found, or
        when the backend cannot be created.
    """
    # Imported at call time as in the original; not referenced below.
    from offwork.worker.sandbox import DockerSandbox

    if concurrency < 1:
        raise ValueError(f"concurrency must be >= 1, got {concurrency!r}")

    resolved = _resolve_url(url)
    auto_tag = "on" if auto_install else "off"
    sandbox_tag = "docker" if sandbox else "off"
    signing_tag = "on" if require_signing else "off"
    logger.info(
        "offwork worker v%s \u2502 %s \u2502 concurrency=%d \u2502 "
        "auto_install=%s \u2502 sandbox=%s \u2502 signing=%s",
        _VERSION, resolved, concurrency, auto_tag, sandbox_tag, signing_tag,
    )

    # Load signing key if required
    signing_key: bytes | None = None
    if require_signing:
        signing_key = resolve_signing_key("worker")
        if signing_key is None:
            logger.error(
                "Signing is enabled but no key material found. "
                "Set OFFWORK_SIGNING_TOKEN, run 'offwork token generate', "
                "or run 'offwork pair' to pair with a client."
            )
            sys.exit(1)
        logger.info("Task signing enabled — only signed tasks will be executed")

    try:
        backend = connect(resolved)
    except Exception as exc:
        logger.error("Could not connect to %s: %s", resolved, exc)
        sys.exit(1)

    # Set only after the sandbox has started, so cleanup never tries to
    # stop a sandbox that did not come up.
    started_sandbox = None
    try:
        worker = Worker(
            auto_install=auto_install,
            import_to_package=import_to_package,
            sandbox=sandbox,
        )

        # Boot the sandbox container before accepting tasks so the first
        # execution doesn't pay the startup cost.
        if worker.sandboxed:
            assert worker._sandbox is not None
            await worker._sandbox.start()
            started_sandbox = worker._sandbox

        logger.info("Listening for tasks \u2014 Ctrl+C to stop.")

        await _worker_loop(worker, backend, concurrency, signing_key=signing_key)
    finally:
        try:
            if started_sandbox is not None:
                await started_sandbox.stop()
        finally:
            # Runs even if stopping the sandbox raised.
            await disconnect()
            logger.info("Worker stopped.")