offwork 0.4.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (42)
  1. offwork/__init__.py +167 -0
  2. offwork/__main__.py +770 -0
  3. offwork/_venv.py +174 -0
  4. offwork/core/__init__.py +15 -0
  5. offwork/core/errors.py +83 -0
  6. offwork/core/models.py +174 -0
  7. offwork/core/pairing.py +389 -0
  8. offwork/core/progress.py +91 -0
  9. offwork/core/signing.py +91 -0
  10. offwork/core/task.py +520 -0
  11. offwork/core/token.py +184 -0
  12. offwork/core/version.py +10 -0
  13. offwork/graph/__init__.py +5 -0
  14. offwork/graph/analyzer.py +637 -0
  15. offwork/graph/decorator.py +87 -0
  16. offwork/graph/graph.py +995 -0
  17. offwork/graph/store.py +500 -0
  18. offwork/graph/tracing.py +429 -0
  19. offwork/py.typed +0 -0
  20. offwork/typing.py +48 -0
  21. offwork/worker/__init__.py +18 -0
  22. offwork/worker/backends/__init__.py +3 -0
  23. offwork/worker/backends/base.py +149 -0
  24. offwork/worker/backends/http.py +237 -0
  25. offwork/worker/backends/local.py +452 -0
  26. offwork/worker/backends/rabbitmq.py +410 -0
  27. offwork/worker/backends/redis.py +175 -0
  28. offwork/worker/deps.py +365 -0
  29. offwork/worker/remote.py +793 -0
  30. offwork/worker/result.py +276 -0
  31. offwork/worker/sandbox/Dockerfile +24 -0
  32. offwork/worker/sandbox/__init__.py +18 -0
  33. offwork/worker/sandbox/_protocol.py +50 -0
  34. offwork/worker/sandbox/docker.py +438 -0
  35. offwork/worker/sandbox/guest_agent.py +622 -0
  36. offwork/worker/schedule.py +26 -0
  37. offwork/worker/worker.py +263 -0
  38. offwork-0.4.0.dist-info/METADATA +143 -0
  39. offwork-0.4.0.dist-info/RECORD +42 -0
  40. offwork-0.4.0.dist-info/WHEEL +4 -0
  41. offwork-0.4.0.dist-info/entry_points.txt +3 -0
  42. offwork-0.4.0.dist-info/licenses/LICENSE +661 -0
@@ -0,0 +1,237 @@
1
+ """HTTP(S) backend for hosted broker deployments."""
2
+
3
+ import json
4
+ import time
5
+ import base64
6
+ import asyncio
7
+ from typing import Any
8
+ from urllib.error import HTTPError, URLError
9
+ from urllib.parse import parse_qsl, urlencode, urlparse, urlunparse
10
+ from urllib.request import Request, urlopen
11
+ from collections.abc import AsyncIterator
12
+
13
+ from offwork.worker.backends.base import Backend
14
+
15
# Broker path appended to the base URL when the caller supplies none.
_DEFAULT_BROKER_PATH: str = "/api/v1/broker"
# Seconds the server is asked to hold each long-poll request open.
_DEFAULT_LONG_POLL_SECONDS: float = 30.0
17
+
18
+
19
class HttpBackend(Backend):
    """HTTP(S)-based backend for hosted offwork brokers.

    The base URL can point either at the broker root or at the service root;
    when no path is provided, ``/api/v1/broker`` is assumed.

    Authentication is intentionally simple for the proof-of-concept:
    include ``?api_key=...`` in the URL and the backend will move it into the
    ``X-Offwork-API-Key`` request header.

    Each request is a blocking :mod:`urllib` call executed in a worker thread
    via :func:`asyncio.to_thread`, so the event loop itself is never blocked.

    Parameters
    ----------
    url
        ``http(s)://host[:port][/path][?api_key=...]``.
    request_timeout
        Socket timeout in seconds for requests that do not pass their own
        timeout (everything except the long-poll endpoints). ``None``
        disables the timeout entirely.
    """

    # Default socket timeout for ordinary (non long-poll) requests. Passing
    # ``timeout=None`` straight to ``urlopen`` disables the timeout, and a
    # stalled server would then pin the worker thread forever.
    _DEFAULT_REQUEST_TIMEOUT = 30.0

    def __init__(
        self,
        url: str,
        *,
        request_timeout: float | None = _DEFAULT_REQUEST_TIMEOUT,
    ) -> None:
        parsed = urlparse(url)
        if parsed.scheme not in {"http", "https"}:
            raise ValueError(f"Unsupported HTTP backend scheme: {parsed.scheme!r}")

        api_key = ""
        filtered_query: list[tuple[str, str]] = []
        for key, value in parse_qsl(parsed.query, keep_blank_values=True):
            if key == "api_key":
                # Strip *every* api_key parameter so the secret never ends up
                # in request URLs (and server access logs); the first
                # non-empty value wins.
                api_key = api_key or value
                continue
            filtered_query.append((key, value))

        path = parsed.path.rstrip("/") or _DEFAULT_BROKER_PATH
        self._base_url = urlunparse(
            parsed._replace(path=path, query=urlencode(filtered_query))
        )
        self._api_key = api_key or None
        self._request_timeout = request_timeout

    # -- helpers ---------------------------------------------------------------

    def _headers(self) -> dict[str, str]:
        """Build request headers, including the API key when one was configured."""
        headers = {"Content-Type": "application/json"}
        if self._api_key:
            headers["X-Offwork-API-Key"] = self._api_key
        return headers

    def _url(self, suffix: str, query: dict[str, str | float | int] | None = None) -> str:
        """Join *suffix* onto the base URL and append *query* (values stringified)."""
        url = f"{self._base_url}{suffix}"
        if not query:
            return url
        encoded = urlencode({key: str(value) for key, value in query.items()})
        # The base URL may already carry non-api_key query parameters.
        separator = "&" if "?" in url else "?"
        return f"{url}{separator}{encoded}"

    @staticmethod
    def _string_field(body: Any, key: str) -> str | None:
        """Return ``body[key]`` if *body* is a JSON object and the value is a string."""
        if not isinstance(body, dict):
            return None
        value = body.get(key)
        return value if isinstance(value, str) else None

    def _do_request(
        self,
        method: str,
        suffix: str,
        *,
        payload: dict[str, Any] | None = None,
        query: dict[str, str | float | int] | None = None,
        timeout: float | None = None,
        allow_not_found: bool = False,
    ) -> tuple[int, Any | None]:
        """Perform one blocking HTTP request and decode its JSON body.

        Returns ``(status, body)``, where *body* is the decoded JSON or
        ``None`` when the response is empty. With *allow_not_found*, a 404
        reply yields ``(404, None)`` instead of raising. A *timeout* of
        ``None`` falls back to the instance's ``request_timeout``.

        Raises:
            RuntimeError: the server replied with an HTTP error status.
            ConnectionError: the server was unreachable or the socket timed out.
        """
        data = None if payload is None else json.dumps(payload).encode("utf-8")
        request = Request(
            self._url(suffix, query=query),
            data=data,
            method=method,
            headers=self._headers(),
        )
        effective_timeout = self._request_timeout if timeout is None else timeout
        try:
            with urlopen(request, timeout=effective_timeout) as response:
                raw = response.read()
                if not raw:
                    return response.status, None
                return response.status, json.loads(raw.decode("utf-8"))
        except HTTPError as exc:
            try:
                # urllib does not raise HTTPError for 2xx codes, so in practice
                # only 404 matches here; 204 is kept for parity.
                if exc.code in {204, 404} and allow_not_found:
                    return exc.code, None
                message = exc.read().decode("utf-8", errors="replace")
            finally:
                exc.close()  # release the underlying response/socket
            raise RuntimeError(
                f"HTTP backend request failed: {method} {suffix} -> {exc.code} {message}"
            ) from exc
        except URLError as exc:
            raise ConnectionError(f"HTTP backend connection failed: {exc.reason}") from exc
        except TimeoutError as exc:
            # Read-phase socket timeouts are not wrapped in URLError by urllib;
            # surface them like any other transport failure.
            raise ConnectionError(
                f"HTTP backend request timed out: {method} {suffix}"
            ) from exc

    async def _request(
        self,
        method: str,
        suffix: str,
        *,
        payload: dict[str, Any] | None = None,
        query: dict[str, str | float | int] | None = None,
        timeout: float | None = None,
        allow_not_found: bool = False,
    ) -> tuple[int, Any | None]:
        """Run :meth:`_do_request` in a worker thread."""
        return await asyncio.to_thread(
            self._do_request,
            method,
            suffix,
            payload=payload,
            query=query,
            timeout=timeout,
            allow_not_found=allow_not_found,
        )

    # -- Backend interface -----------------------------------------------------

    async def submit(self, task_json: str) -> None:
        await self._request("POST", "/tasks", payload={"task_json": task_json})

    async def listen(self) -> AsyncIterator[str]:
        """Yield claimed task JSON strings, long-polling the broker forever."""
        while True:
            _status, body = await self._request(
                "POST",
                "/tasks/claim",
                payload={"wait_seconds": _DEFAULT_LONG_POLL_SECONDS},
                # Give the server the full poll window plus slack to reply.
                timeout=_DEFAULT_LONG_POLL_SECONDS + 5.0,
                allow_not_found=True,
            )
            task_json = self._string_field(body, "task_json")
            if task_json is not None:
                yield task_json

    async def send_result(self, task_id: str, result_json: str) -> None:
        await self._request(
            "POST",
            f"/tasks/{task_id}/result",
            payload={"result_json": result_json},
        )

    async def get_result(self, task_id: str, timeout: float | None = None) -> str:
        """Long-poll until the result of *task_id* is available and return it.

        Each poll asks the server to wait up to ``_DEFAULT_LONG_POLL_SECONDS``,
        capped by the time left before *timeout* (``None`` waits forever).

        Raises:
            TimeoutError: no result arrived within *timeout* seconds.
        """
        deadline = None if timeout is None else time.monotonic() + timeout
        while True:
            remaining = None if deadline is None else max(0.0, deadline - time.monotonic())
            if remaining is None:
                wait_seconds = _DEFAULT_LONG_POLL_SECONDS
            else:
                wait_seconds = min(_DEFAULT_LONG_POLL_SECONDS, remaining)
            status, body = await self._request(
                "GET",
                f"/tasks/{task_id}/result",
                query={"wait_seconds": wait_seconds},
                timeout=wait_seconds + 5.0,
                allow_not_found=True,
            )
            result_json = self._string_field(body, "result_json")
            if result_json is not None:
                return result_json
            if deadline is None or remaining is None:
                continue
            # A 204 reply means the server held the poll open without a result.
            # That only exhausts the caller's budget when this poll covered the
            # whole remaining window; with more time left, keep polling.
            final_window = remaining <= _DEFAULT_LONG_POLL_SECONDS
            if (status == 204 and final_window) or time.monotonic() >= deadline:
                raise TimeoutError(f"Timed out waiting for result of task {task_id}")

    async def try_get_result(self, task_id: str) -> str | None:
        """Return the result of *task_id* if the server already has it, else ``None``."""
        _status, body = await self._request(
            "GET",
            f"/tasks/{task_id}/result",
            query={"wait_seconds": 0},
            allow_not_found=True,
        )
        return self._string_field(body, "result_json")

    async def send_heartbeat(self, task_id: str) -> None:
        await self._request("POST", f"/tasks/{task_id}/heartbeat")

    async def get_heartbeat(self, task_id: str) -> float | None:
        """Return the last heartbeat value reported for *task_id*, or ``None``."""
        _status, body = await self._request(
            "GET", f"/tasks/{task_id}/heartbeat", allow_not_found=True,
        )
        if not isinstance(body, dict):
            return None
        raw = body.get("heartbeat")
        return float(raw) if isinstance(raw, (int, float)) else None

    async def cancel_task(self, task_id: str) -> None:
        await self._request("POST", f"/tasks/{task_id}/cancel")

    async def is_cancelled(self, task_id: str) -> bool:
        _status, body = await self._request(
            "GET", f"/tasks/{task_id}/cancel", allow_not_found=True,
        )
        return isinstance(body, dict) and bool(body.get("cancelled"))

    async def send_progress(self, task_id: str, progress_json: str) -> None:
        await self._request(
            "POST",
            f"/tasks/{task_id}/progress",
            payload={"progress_json": progress_json},
        )

    async def get_progress(self, task_id: str) -> str | None:
        """Return the latest progress JSON for *task_id*, or ``None``."""
        _status, body = await self._request(
            "GET", f"/tasks/{task_id}/progress", allow_not_found=True,
        )
        return self._string_field(body, "progress_json")

    async def cancel_schedule(self, schedule_id: str) -> None:
        await self._request("POST", f"/schedules/{schedule_id}/cancel")

    async def is_schedule_cancelled(self, schedule_id: str) -> bool:
        _status, body = await self._request(
            "GET", f"/schedules/{schedule_id}/cancel", allow_not_found=True,
        )
        return isinstance(body, dict) and bool(body.get("cancelled"))

    async def check_throttle(self, function_name: str) -> bool:
        """Return whether *function_name* may run now (allowed when the server is silent)."""
        _status, body = await self._request(
            "GET",
            "/throttle/check",
            query={"function_name": function_name},
            allow_not_found=True,
        )
        if not isinstance(body, dict):
            return True
        return bool(body.get("allowed", True))

    async def record_throttle(self, function_name: str, throttle_seconds: float) -> None:
        await self._request(
            "POST",
            "/throttle/record",
            payload={
                "function_name": function_name,
                "throttle_seconds": throttle_seconds,
            },
        )

    async def close(self) -> None:
        # No persistent connections are held; each request opens its own.
        return None
@@ -0,0 +1,452 @@
1
+ """Async-native TCP backend for same-machine IPC.
2
+
3
+ A lightweight broker server built on :mod:`asyncio` TCP streams handles
4
+ task dispatch, result routing, and heartbeats -- no threads, no
5
+ :mod:`multiprocessing`, no external services.
6
+
7
+ URL scheme: ``local://host:port`` (default ``local://127.0.0.1:9748``)
8
+ """
9
+
10
+ import sys
11
+ import json
12
+ import time
13
+ import atexit
14
+ import socket
15
+ import struct
16
+ import asyncio
17
+ import logging
18
+ import contextlib
19
+ import subprocess
20
+ from typing import Any
21
+ from urllib.parse import urlparse
22
+ from collections.abc import AsyncIterator
23
+
24
+ from offwork.worker.backends.base import Backend
25
+
26
# Module-level logger (used by run_broker and LocalBackend).
logger: logging.Logger = logging.getLogger(__name__)

# Wire framing: each message is a 4-byte unsigned big-endian length prefix
# followed by that many bytes of JSON.
_HEADER: struct.Struct = struct.Struct("!I")
# Address used when the local:// URL omits the host and/or the port.
_DEFAULT_HOST: str = "127.0.0.1"
_DEFAULT_PORT: int = 9748
31
+
32
+
33
+ # ---------------------------------------------------------------------------
34
+ # Wire protocol
35
+ # ---------------------------------------------------------------------------
36
+
37
+
38
async def _send_msg(writer: asyncio.StreamWriter, obj: dict[str, Any]) -> None:
    """Serialize *obj* as compact JSON and write it as one length-prefixed frame.

    The frame is the ``_HEADER``-packed payload length followed by the encoded
    payload. The writer is drained before returning so back-pressure applies.
    """
    body = json.dumps(obj, separators=(",", ":")).encode()
    frame = b"".join((_HEADER.pack(len(body)), body))
    writer.write(frame)
    await writer.drain()
43
+
44
+
45
async def _recv_msg(reader: asyncio.StreamReader) -> dict[str, Any]:
    """Read exactly one length-prefixed JSON frame from *reader* and decode it.

    Raises :class:`asyncio.IncompleteReadError` if the stream reaches EOF
    before the header or the full payload has arrived.
    """
    header = await reader.readexactly(_HEADER.size)
    size = _HEADER.unpack(header)[0]
    message: dict[str, Any] = json.loads(await reader.readexactly(size))
    return message
55
+
56
+
57
+ # ---------------------------------------------------------------------------
58
+ # Broker server (pure asyncio)
59
+ # ---------------------------------------------------------------------------
60
+
61
+
62
class _Broker:
    """Task broker backed entirely by asyncio primitives.

    All state is in memory for the lifetime of the broker process:

    * ``_tasks`` -- FIFO of submitted task JSON strings, drained by listeners.
    * ``_results`` -- per-task single-slot queue holding a result until it is
      read. While the slot is full, later ``result_put`` calls are dropped.
      A slot is removed once its result has been handed out.
    * heartbeats, task/schedule cancellation flags, progress, and throttle
      expiries are plain dicts/sets keyed by id or function name.
    * ``_result_subs`` -- one queue per ``subscribe`` connection; every
      ``result_put`` pushes the task id onto each of them.
    """

    def __init__(self) -> None:
        self._tasks: asyncio.Queue[str] = asyncio.Queue()
        self._results: dict[str, asyncio.Queue[str]] = {}
        self._heartbeats: dict[str, float] = {}
        self._cancelled: set[str] = set()
        self._cancelled_schedules: set[str] = set()
        self._throttles: dict[str, float] = {}  # function name -> time.time() expiry
        self._progress: dict[str, str] = {}
        self._result_subs: list[asyncio.Queue[str]] = []

    def _result_slot(self, task_id: str) -> asyncio.Queue[str]:
        """Return the single-item result queue for *task_id*, creating it if needed."""
        if task_id not in self._results:
            self._results[task_id] = asyncio.Queue(maxsize=1)
        return self._results[task_id]

    # -- connection handler ----------------------------------------------------

    async def handle(
        self,
        reader: asyncio.StreamReader,
        writer: asyncio.StreamWriter,
    ) -> None:
        """Serve one client connection.

        The first message selects the mode: ``listen`` streams tasks,
        ``subscribe`` streams finished task ids, and anything else starts a
        request/response loop. EOF and socket errors end the connection
        quietly; the writer is always closed.
        """
        try:
            msg = await _recv_msg(reader)
            op = msg.get("op", "")
            if op == "listen":
                await self._stream_tasks(writer)
            elif op == "subscribe":
                await self._stream_results(writer)
            else:
                await self._rpc_loop(msg, reader, writer)
        except (asyncio.IncompleteReadError, ConnectionError, OSError):
            pass
        finally:
            writer.close()
            with contextlib.suppress(ConnectionError, OSError):
                await writer.wait_closed()

    async def _rpc_loop(
        self,
        first: dict[str, Any],
        reader: asyncio.StreamReader,
        writer: asyncio.StreamWriter,
    ) -> None:
        """Answer *first*, then every following request, until the client disconnects."""
        await _send_msg(writer, self._dispatch(first))
        while True:
            msg = await _recv_msg(reader)
            await _send_msg(writer, self._dispatch(msg))

    # -- dispatch (sync -- no awaits, safe for single-threaded asyncio) --------

    def _dispatch(self, msg: dict[str, Any]) -> dict[str, Any]:
        """Apply one request to broker state and build the reply.

        Replies are ``{"ok": True}`` or ``{"ok": True, "data": ...}``; an
        unknown op yields ``{"ok": False, "error": ...}``. Missing required
        keys raise :class:`KeyError`.
        """
        op = msg["op"]
        if op == "submit":
            self._tasks.put_nowait(msg["data"])
            return {"ok": True}
        if op == "result_put":
            try:
                self._result_slot(msg["task_id"]).put_nowait(msg["data"])
            except asyncio.QueueFull:
                pass  # first result wins (e.g., cancel before worker result)
            for sub in self._result_subs:
                sub.put_nowait(msg["task_id"])
            return {"ok": True}
        if op == "result_try":
            task_id = msg["task_id"]
            # Look the slot up without creating it: clients poll for results
            # of unfinished (or unknown) tasks, and allocating a slot per poll
            # would grow ``_results`` without bound.
            slot = self._results.get(task_id)
            if slot is None:
                return {"ok": True, "data": None}
            try:
                data = slot.get_nowait()
            except asyncio.QueueEmpty:
                return {"ok": True, "data": None}
            # The result has been handed out; drop the now-empty slot.
            del self._results[task_id]
            return {"ok": True, "data": data}
        if op == "hb_put":
            self._heartbeats[msg["task_id"]] = msg["ts"]
            return {"ok": True}
        if op == "hb_get":
            return {"ok": True, "data": self._heartbeats.get(msg["task_id"])}
        if op == "hb_batch":
            return {
                "ok": True,
                "data": {tid: self._heartbeats.get(tid) for tid in msg["task_ids"]},
            }
        if op == "cancel":
            self._cancelled.add(msg["task_id"])
            return {"ok": True}
        if op == "is_cancelled":
            return {"ok": True, "data": msg["task_id"] in self._cancelled}
        if op == "progress_put":
            self._progress[msg["task_id"]] = msg["data"]
            return {"ok": True}
        if op == "progress_get":
            return {"ok": True, "data": self._progress.get(msg["task_id"])}
        if op == "schedule_cancel":
            self._cancelled_schedules.add(msg["schedule_id"])
            return {"ok": True}
        if op == "schedule_check":
            return {"ok": True, "data": msg["schedule_id"] in self._cancelled_schedules}
        if op == "throttle_check":
            fn = msg["function_name"]
            expiry = self._throttles.get(fn)
            if expiry is not None and time.time() < expiry:
                return {"ok": True, "data": False}
            self._throttles.pop(fn, None)  # expired (or absent): forget it
            return {"ok": True, "data": True}
        if op == "throttle_record":
            fn = msg["function_name"]
            self._throttles[fn] = time.time() + msg["seconds"]
            return {"ok": True}
        return {"ok": False, "error": f"unknown op: {op}"}

    # -- streaming handlers ----------------------------------------------------

    async def _stream_tasks(self, writer: asyncio.StreamWriter) -> None:
        """Push queued tasks to a listener until its connection fails."""
        while True:
            try:
                task = await asyncio.wait_for(self._tasks.get(), timeout=1.0)
            except asyncio.TimeoutError:
                continue
            try:
                await _send_msg(writer, {"data": task})
            except (ConnectionError, OSError):
                # Client gone -- re-queue the task (at the back of the queue)
                # so another listener gets it.
                self._tasks.put_nowait(task)
                return

    async def _stream_results(self, writer: asyncio.StreamWriter) -> None:
        """Push ids of tasks with new results to a subscriber until it disconnects."""
        q: asyncio.Queue[str] = asyncio.Queue()
        self._result_subs.append(q)
        try:
            while True:
                try:
                    task_id = await asyncio.wait_for(q.get(), timeout=1.0)
                except asyncio.TimeoutError:
                    continue
                try:
                    await _send_msg(writer, {"data": task_id})
                except (ConnectionError, OSError):
                    return
        finally:
            self._result_subs.remove(q)
202
+
203
+
204
async def run_broker(host: str, port: int) -> None:
    """Serve a fresh broker on ``host:port`` until the task is cancelled."""
    server = await asyncio.start_server(_Broker().handle, host, port)
    logger.info("Local broker listening on %s:%d", host, port)
    async with server:
        await server.serve_forever()
211
+
212
+
213
def _broker_main(host: str, port: int) -> None:
    """Subprocess entry point: run the broker on a new event loop until killed."""
    main_coro = run_broker(host, port)
    asyncio.run(main_coro)
216
+
217
+
218
+ # ---------------------------------------------------------------------------
219
+ # URL parsing
220
+ # ---------------------------------------------------------------------------
221
+
222
+
223
def _parse_local_url(url: str) -> tuple[str, int]:
    """Split a ``local://host:port`` URL into ``(host, port)``, filling in defaults."""
    parts = urlparse(url)
    return (parts.hostname or _DEFAULT_HOST, parts.port or _DEFAULT_PORT)
228
+
229
+
230
+ # ---------------------------------------------------------------------------
231
+ # LocalBackend
232
+ # ---------------------------------------------------------------------------
233
+
234
+
235
class LocalBackend(Backend):
    """Async-native TCP backend for same-machine IPC.

    A lightweight broker process handles task dispatch, result routing,
    and heartbeats over TCP on localhost. Every I/O operation on the broker
    connection is a native :mod:`asyncio` coroutine -- no threads anywhere.
    (Broker startup in ``__init__`` is synchronous: it probes the port and
    may spawn a subprocess, sleeping while it waits for it to come up.)

    Request/response operations share one persistent connection serialized
    by an :class:`asyncio.Lock`. :meth:`listen` and :meth:`subscribe_results`
    each open their own streaming connection.

    Parameters
    ----------
    url
        ``local://host:port``. The default ``local://localhost`` uses port
        9748; a missing host falls back to ``127.0.0.1``.
    server
        ``False`` to only connect to an existing broker. ``True`` or ``None``
        (default) probe the address and spawn a broker subprocess only if
        nothing is accepting connections there yet.
    """

    def __init__(
        self,
        url: str = "local://localhost",
        *,
        server: bool | None = None,
    ) -> None:
        self._host, self._port = _parse_local_url(url)
        self._reader: asyncio.StreamReader | None = None
        self._writer: asyncio.StreamWriter | None = None
        self._lock = asyncio.Lock()
        self._server_proc: subprocess.Popen[bytes] | None = None

        self._ensure_broker(server)
        logger.info(
            "LocalBackend ready (server=%s, %s:%d)",
            self._server_proc is not None, self._host, self._port,
        )

    # -- broker lifecycle ------------------------------------------------------

    def _ensure_broker(self, server: bool | None) -> None:
        """Start a broker subprocess unless told not to or one is already reachable."""
        if server is False:
            return
        if self._probe():
            return
        self._start_broker()

    def _probe(self) -> bool:
        """Return ``True`` if something accepts TCP connections on host:port."""
        try:
            # create_connection resolves IPv4 and IPv6 addresses, and the
            # ``with`` block closes the socket on every path.
            with socket.create_connection((self._host, self._port), timeout=0.5):
                return True
        except OSError:  # includes ConnectionRefusedError and socket timeouts
            return False

    def _start_broker(self) -> None:
        """Spawn a broker subprocess and wait up to ~5 s for it to accept connections.

        Raises:
            ConnectionError: no broker became reachable on host:port.
        """
        proc = subprocess.Popen(
            [
                sys.executable, "-c",
                "from offwork.worker.backends.local import _broker_main; "
                f"_broker_main({self._host!r}, {self._port})",
            ],
            stdout=subprocess.DEVNULL,
            stderr=subprocess.DEVNULL,
        )
        self._server_proc = proc
        atexit.register(self._kill_broker)
        for _ in range(50):  # up to ~5 s
            if self._probe():
                if proc.poll() is not None:
                    # Our child already exited (typically it lost a race to
                    # bind the port) but another broker is serving; we do not
                    # own that server.
                    self._server_proc = None
                return
            if proc.poll() is not None:
                break  # our broker died and nothing else is listening
            time.sleep(0.1)
        self._kill_broker()  # don't leave a half-started broker behind
        raise ConnectionError(
            f"Broker failed to start on {self._host}:{self._port}"
        )

    def _kill_broker(self) -> None:
        """Stop the broker subprocess this instance started, if still running.

        Registered with :mod:`atexit`, so it is safe to call more than once.
        """
        proc = self._server_proc
        if proc is not None and proc.poll() is None:
            proc.terminate()
            try:
                proc.wait(timeout=5)
            except subprocess.TimeoutExpired:
                proc.kill()
                # Reap the killed process so it doesn't linger as a zombie.
                with contextlib.suppress(subprocess.TimeoutExpired):
                    proc.wait(timeout=5)
        self._server_proc = None

    # -- TCP connection --------------------------------------------------------

    async def _connect(self) -> tuple[asyncio.StreamReader, asyncio.StreamWriter]:
        """Return the shared request connection, opening it on first use."""
        if self._reader is None or self._writer is None:
            self._reader, self._writer = await asyncio.open_connection(
                self._host, self._port,
            )
        return self._reader, self._writer

    def _drop_connection(self) -> None:
        """Close and forget the shared request connection (next call reconnects)."""
        if self._writer is not None:
            self._writer.close()
        self._reader = self._writer = None

    async def _request(self, msg: dict[str, Any]) -> dict[str, Any]:
        """Send *msg* over the shared connection and return the broker's reply."""
        async with self._lock:
            try:
                reader, writer = await self._connect()
                await _send_msg(writer, msg)
                return await _recv_msg(reader)
            except BaseException:
                # Any failure between sending and fully reading the reply --
                # including task cancellation -- can leave a partial or unread
                # frame on the connection. The next request would then read a
                # stale reply, so drop the connection before propagating.
                self._drop_connection()
                raise

    # -- Backend interface -----------------------------------------------------

    async def submit(self, task_json: str) -> None:
        await self._request({"op": "submit", "data": task_json})

    async def listen(self) -> AsyncIterator[str]:
        """Yield task JSON strings pushed by the broker on a dedicated connection.

        The iterator ends quietly when the connection is closed or fails.
        """
        reader, writer = await asyncio.open_connection(self._host, self._port)
        try:
            await _send_msg(writer, {"op": "listen"})
            while True:
                msg = await _recv_msg(reader)
                yield msg["data"]
        except (asyncio.IncompleteReadError, ConnectionError, OSError):
            return
        finally:
            writer.close()
            with contextlib.suppress(ConnectionError, OSError):
                await writer.wait_closed()

    async def send_result(self, task_id: str, result_json: str) -> None:
        await self._request({
            "op": "result_put", "task_id": task_id, "data": result_json,
        })

    async def get_result(self, task_id: str, timeout: float | None = None) -> str:
        """Poll every 50 ms until the result of *task_id* arrives.

        Raises:
            TimeoutError: *timeout* seconds elapsed without a result.
        """
        deadline = None if timeout is None else time.monotonic() + timeout
        while True:
            raw = await self.try_get_result(task_id)
            if raw is not None:
                return raw
            if deadline is not None and time.monotonic() >= deadline:
                raise TimeoutError(
                    f"Timed out waiting for result of task {task_id}"
                )
            await asyncio.sleep(0.05)

    async def try_get_result(self, task_id: str) -> str | None:
        resp = await self._request({"op": "result_try", "task_id": task_id})
        return resp.get("data")

    async def send_heartbeat(self, task_id: str) -> None:
        await self._request({
            "op": "hb_put", "task_id": task_id, "ts": time.time(),
        })

    async def get_heartbeat(self, task_id: str) -> float | None:
        resp = await self._request({"op": "hb_get", "task_id": task_id})
        return resp.get("data")

    async def get_heartbeats(self, task_ids: list[str]) -> dict[str, float | None]:
        """Fetch heartbeats for several tasks in a single round trip."""
        resp = await self._request({"op": "hb_batch", "task_ids": task_ids})
        result: dict[str, float | None] = resp.get("data", {})
        return result

    async def cancel_task(self, task_id: str) -> None:
        await self._request({"op": "cancel", "task_id": task_id})

    async def is_cancelled(self, task_id: str) -> bool:
        resp = await self._request({"op": "is_cancelled", "task_id": task_id})
        return bool(resp.get("data", False))

    async def send_progress(self, task_id: str, progress_json: str) -> None:
        await self._request({
            "op": "progress_put", "task_id": task_id, "data": progress_json,
        })

    async def get_progress(self, task_id: str) -> str | None:
        resp = await self._request({"op": "progress_get", "task_id": task_id})
        return resp.get("data")

    async def cancel_schedule(self, schedule_id: str) -> None:
        await self._request({"op": "schedule_cancel", "schedule_id": schedule_id})

    async def is_schedule_cancelled(self, schedule_id: str) -> bool:
        resp = await self._request({"op": "schedule_check", "schedule_id": schedule_id})
        return bool(resp.get("data", False))

    async def check_throttle(self, function_name: str) -> bool:
        resp = await self._request({"op": "throttle_check", "function_name": function_name})
        return bool(resp.get("data", True))

    async def record_throttle(
        self, function_name: str, throttle_seconds: float,
    ) -> None:
        await self._request({
            "op": "throttle_record",
            "function_name": function_name,
            "seconds": throttle_seconds,
        })

    async def notify_result(self, task_id: str) -> None:
        pass  # broker dispatches notifications inside result_put

    async def subscribe_results(self) -> AsyncIterator[str]:
        """Yield ids of tasks whose results were posted, on a dedicated connection.

        The iterator ends quietly when the connection is closed or fails.
        """
        reader, writer = await asyncio.open_connection(self._host, self._port)
        try:
            await _send_msg(writer, {"op": "subscribe"})
            while True:
                msg = await _recv_msg(reader)
                yield msg["data"]
        except (asyncio.IncompleteReadError, ConnectionError, OSError):
            return
        finally:
            writer.close()
            with contextlib.suppress(ConnectionError, OSError):
                await writer.wait_closed()

    async def close(self) -> None:
        """Close the shared connection and stop any broker this instance started."""
        if self._writer is not None:
            self._writer.close()
            with contextlib.suppress(ConnectionError, OSError):
                await self._writer.wait_closed()
        self._reader = self._writer = None
        self._kill_broker()