offwork 0.4.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (42) hide show
  1. offwork/__init__.py +167 -0
  2. offwork/__main__.py +770 -0
  3. offwork/_venv.py +174 -0
  4. offwork/core/__init__.py +15 -0
  5. offwork/core/errors.py +83 -0
  6. offwork/core/models.py +174 -0
  7. offwork/core/pairing.py +389 -0
  8. offwork/core/progress.py +91 -0
  9. offwork/core/signing.py +91 -0
  10. offwork/core/task.py +520 -0
  11. offwork/core/token.py +184 -0
  12. offwork/core/version.py +10 -0
  13. offwork/graph/__init__.py +5 -0
  14. offwork/graph/analyzer.py +637 -0
  15. offwork/graph/decorator.py +87 -0
  16. offwork/graph/graph.py +995 -0
  17. offwork/graph/store.py +500 -0
  18. offwork/graph/tracing.py +429 -0
  19. offwork/py.typed +0 -0
  20. offwork/typing.py +48 -0
  21. offwork/worker/__init__.py +18 -0
  22. offwork/worker/backends/__init__.py +3 -0
  23. offwork/worker/backends/base.py +149 -0
  24. offwork/worker/backends/http.py +237 -0
  25. offwork/worker/backends/local.py +452 -0
  26. offwork/worker/backends/rabbitmq.py +410 -0
  27. offwork/worker/backends/redis.py +175 -0
  28. offwork/worker/deps.py +365 -0
  29. offwork/worker/remote.py +793 -0
  30. offwork/worker/result.py +276 -0
  31. offwork/worker/sandbox/Dockerfile +24 -0
  32. offwork/worker/sandbox/__init__.py +18 -0
  33. offwork/worker/sandbox/_protocol.py +50 -0
  34. offwork/worker/sandbox/docker.py +438 -0
  35. offwork/worker/sandbox/guest_agent.py +622 -0
  36. offwork/worker/schedule.py +26 -0
  37. offwork/worker/worker.py +263 -0
  38. offwork-0.4.0.dist-info/METADATA +143 -0
  39. offwork-0.4.0.dist-info/RECORD +42 -0
  40. offwork-0.4.0.dist-info/WHEEL +4 -0
  41. offwork-0.4.0.dist-info/entry_points.txt +3 -0
  42. offwork-0.4.0.dist-info/licenses/LICENSE +661 -0
offwork/core/task.py ADDED
@@ -0,0 +1,520 @@
1
+ """Task dataclass: serializable envelope bundling a graph with arguments."""
2
+
3
+ import base64
4
+ import collections
5
+ import datetime as _dt
6
+ import enum
7
+ import ipaddress
8
+ import json
9
+ import pathlib
10
+ import pickle
11
+ import re
12
+ import uuid
13
+ from decimal import Decimal
14
+ from fractions import Fraction
15
+ from typing import Any, Self
16
+ from dataclasses import field, dataclass
17
+
18
+ from offwork.core.errors import SignatureError
19
+ from offwork.core.signing import verify_signature, compute_signature
20
+
21
+ _OBJECT_SENTINEL = "__offwork_obj__"
22
+ _BYTES_SENTINEL = "__offwork_bytes__"
23
+ _BUILTIN_SENTINEL = "__offwork_builtin__"
24
+ _TUPLE_SENTINEL = "__offwork_tuple__"
25
+ _DICT_SENTINEL = "__offwork_dict__"
26
+ _PICKLE_SENTINEL = "__offwork_pickle__"
27
+
28
+
29
+ _FACTORY_BY_NAME: dict[str, Any] = {
30
+ "int": int, "list": list, "dict": dict, "set": set,
31
+ "tuple": tuple, "frozenset": frozenset, "str": str, "float": float,
32
+ "bytes": bytes, "bool": bool,
33
+ }
34
+
35
+
36
+ def _encode_factory(factory: Any) -> str | None:
37
+ """Encode a defaultdict factory if it's a recognised builtin, else None."""
38
+ if factory is None:
39
+ return None
40
+ name = getattr(factory, "__name__", None)
41
+ if isinstance(name, str) and _FACTORY_BY_NAME.get(name) is factory:
42
+ return name
43
+ return None
44
+
45
+
46
+ def _encode_builtin(o: object) -> dict[str, Any] | None:
47
+ """Encode common stdlib types to a JSON-safe sentinel payload.
48
+
49
+ Returns ``None`` if *o* is not a recognised builtin -- the caller
50
+ falls back to object-state serialization or pickle.
51
+ """
52
+ # IntEnum / StrEnum subclass int / str, so check enum first.
53
+ if isinstance(o, enum.Enum):
54
+ return {
55
+ "type": "enum",
56
+ "cls": type(o).__name__,
57
+ "name": o.name,
58
+ "value": _to_jsonable(o.value),
59
+ }
60
+ # datetime is a subclass of date, so check it first.
61
+ if isinstance(o, _dt.datetime):
62
+ return {"type": "datetime", "value": o.isoformat()}
63
+ if isinstance(o, _dt.date):
64
+ return {"type": "date", "value": o.isoformat()}
65
+ if isinstance(o, _dt.time):
66
+ return {"type": "time", "value": o.isoformat()}
67
+ if isinstance(o, _dt.timedelta):
68
+ return {"type": "timedelta", "value": o.total_seconds()}
69
+ if isinstance(o, Decimal):
70
+ return {"type": "decimal", "value": str(o)}
71
+ if isinstance(o, Fraction):
72
+ return {"type": "fraction", "value": [o.numerator, o.denominator]}
73
+ if isinstance(o, uuid.UUID):
74
+ return {"type": "uuid", "value": o.hex}
75
+ if isinstance(o, complex):
76
+ return {"type": "complex", "value": [o.real, o.imag]}
77
+ if isinstance(o, range):
78
+ return {"type": "range", "value": [o.start, o.stop, o.step]}
79
+ if isinstance(o, frozenset):
80
+ return {"type": "frozenset", "value": [_to_jsonable(v) for v in o]}
81
+ if isinstance(o, set):
82
+ return {"type": "set", "value": [_to_jsonable(v) for v in o]}
83
+ if isinstance(o, collections.deque):
84
+ return {
85
+ "type": "deque",
86
+ "value": [_to_jsonable(v) for v in o],
87
+ "maxlen": o.maxlen,
88
+ }
89
+ if isinstance(o, pathlib.PurePath):
90
+ return {"type": "path", "value": str(o), "cls": type(o).__name__}
91
+ if isinstance(
92
+ o,
93
+ (
94
+ ipaddress.IPv4Address, ipaddress.IPv6Address,
95
+ ipaddress.IPv4Network, ipaddress.IPv6Network,
96
+ ipaddress.IPv4Interface, ipaddress.IPv6Interface,
97
+ ),
98
+ ):
99
+ return {"type": "ipaddress", "cls": type(o).__name__, "value": str(o)}
100
+ return None
101
+
102
+
103
+ _PATH_CLASSES: dict[str, type[pathlib.PurePath]] = {
104
+ "PurePath": pathlib.PurePath,
105
+ "PurePosixPath": pathlib.PurePosixPath,
106
+ "PureWindowsPath": pathlib.PureWindowsPath,
107
+ "Path": pathlib.Path,
108
+ "PosixPath": pathlib.PurePosixPath,
109
+ "WindowsPath": pathlib.PureWindowsPath,
110
+ }
111
+
112
+ _IP_CLASSES: dict[str, Any] = {
113
+ "IPv4Address": ipaddress.IPv4Address,
114
+ "IPv6Address": ipaddress.IPv6Address,
115
+ "IPv4Network": ipaddress.IPv4Network,
116
+ "IPv6Network": ipaddress.IPv6Network,
117
+ "IPv4Interface": ipaddress.IPv4Interface,
118
+ "IPv6Interface": ipaddress.IPv6Interface,
119
+ }
120
+
121
+
122
+ def _decode_builtin(info: dict[str, Any], namespace: dict[str, Any]) -> Any:
123
+ """Reverse :func:`_encode_builtin`."""
124
+ kind = info.get("type")
125
+ raw: Any = info.get("value")
126
+ if kind == "datetime":
127
+ return _dt.datetime.fromisoformat(str(raw))
128
+ if kind == "date":
129
+ return _dt.date.fromisoformat(str(raw))
130
+ if kind == "time":
131
+ return _dt.time.fromisoformat(str(raw))
132
+ if kind == "timedelta":
133
+ return _dt.timedelta(seconds=float(raw))
134
+ if kind == "decimal":
135
+ return Decimal(str(raw))
136
+ if kind == "fraction":
137
+ return Fraction(int(raw[0]), int(raw[1]))
138
+ if kind == "uuid":
139
+ return uuid.UUID(hex=str(raw))
140
+ if kind == "complex":
141
+ return complex(raw[0], raw[1])
142
+ if kind == "range":
143
+ return range(raw[0], raw[1], raw[2])
144
+ if kind == "set":
145
+ return {_resolve(v, namespace) for v in raw}
146
+ if kind == "frozenset":
147
+ return frozenset(_resolve(v, namespace) for v in raw)
148
+ if kind == "deque":
149
+ return collections.deque(
150
+ (_resolve(v, namespace) for v in raw),
151
+ maxlen=info.get("maxlen"),
152
+ )
153
+ if kind == "counter":
154
+ return collections.Counter({
155
+ _resolve(k, namespace): v for k, v in info["items"]
156
+ })
157
+ if kind == "ordereddict":
158
+ return collections.OrderedDict(
159
+ (_resolve(k, namespace), _resolve(v, namespace))
160
+ for k, v in info["items"]
161
+ )
162
+ if kind == "defaultdict":
163
+ factory = _FACTORY_BY_NAME.get(info.get("factory") or "")
164
+ dd: collections.defaultdict[Any, Any] = collections.defaultdict(factory)
165
+ for k, v in info["items"]:
166
+ dd[_resolve(k, namespace)] = _resolve(v, namespace)
167
+ return dd
168
+ if kind == "namedtuple":
169
+ cls = namespace.get(info["cls"])
170
+ values = [_resolve(v, namespace) for v in info["values"]]
171
+ if cls is None:
172
+ return tuple(values)
173
+ return cls(*values)
174
+ if kind == "enum":
175
+ cls = namespace.get(info["cls"])
176
+ if cls is None:
177
+ return _resolve(raw, namespace)
178
+ try:
179
+ return cls[info["name"]]
180
+ except KeyError:
181
+ return cls(_resolve(raw, namespace))
182
+ if kind == "path":
183
+ # Try to honour the original class; fall back to a sensible
184
+ # OS-portable default if the concrete subclass cannot be
185
+ # instantiated on this platform.
186
+ cls = _PATH_CLASSES.get(info.get("cls", ""), pathlib.PurePath)
187
+ try:
188
+ return cls(str(raw))
189
+ except (NotImplementedError, TypeError):
190
+ return pathlib.PurePath(str(raw))
191
+ if kind == "ipaddress":
192
+ ip_cls = _IP_CLASSES.get(info.get("cls", ""))
193
+ if ip_cls is None:
194
+ return str(raw)
195
+ return ip_cls(str(raw))
196
+ raise ValueError(f"Unknown builtin sentinel type: {kind!r}")
197
+
198
+
199
+ def _extract_object_state(o: object) -> dict[str, Any] | None:
200
+ """Return the per-instance state dict, or ``None`` if not extractable."""
201
+ if hasattr(o, "__dict__"):
202
+ d = getattr(o, "__dict__", None)
203
+ if isinstance(d, dict):
204
+ return dict(d)
205
+ if hasattr(type(o), "__slots__"):
206
+ all_slots: set[str] = set()
207
+ for klass in type(o).__mro__:
208
+ all_slots.update(getattr(klass, "__slots__", ()))
209
+ all_slots -= {"__weakref__", "__dict__"}
210
+ return {
211
+ slot: getattr(o, slot)
212
+ for slot in sorted(all_slots)
213
+ if hasattr(o, slot)
214
+ }
215
+ return None
216
+
217
+
218
+ def _to_jsonable(o: Any) -> Any:
219
+ """Recursively convert *o* to a JSON-safe value using sentinels.
220
+
221
+ Order of checks matters: ``bool`` and ``IntEnum`` subclass ``int``;
222
+ ``Counter``/``OrderedDict``/``defaultdict`` subclass ``dict``;
223
+ ``NamedTuple`` subclasses ``tuple``.
224
+ """
225
+ # Primitives. None / bool / str pass through; bool must come before int
226
+ # but JSON treats True/False natively so isinstance(_, int) is harmless
227
+ # *after* the enum check.
228
+ if o is None or isinstance(o, (str, bool)):
229
+ return o
230
+ if isinstance(o, enum.Enum):
231
+ return {_BUILTIN_SENTINEL: _encode_builtin(o)}
232
+ if isinstance(o, (int, float)):
233
+ return o
234
+ if isinstance(o, (bytes, bytearray)):
235
+ return {
236
+ _BYTES_SENTINEL: {
237
+ "data": base64.b64encode(bytes(o)).decode("ascii"),
238
+ "type": type(o).__name__,
239
+ }
240
+ }
241
+ if isinstance(o, memoryview):
242
+ return {
243
+ _BYTES_SENTINEL: {
244
+ "data": base64.b64encode(bytes(o)).decode("ascii"),
245
+ "type": "memoryview",
246
+ }
247
+ }
248
+ # NamedTuple before tuple (NamedTuple subclasses tuple).
249
+ if isinstance(o, tuple):
250
+ if hasattr(o, "_fields") and hasattr(o, "_asdict"):
251
+ return {
252
+ _BUILTIN_SENTINEL: {
253
+ "type": "namedtuple",
254
+ "cls": type(o).__name__,
255
+ "fields": list(o._fields),
256
+ "values": [_to_jsonable(v) for v in o],
257
+ }
258
+ }
259
+ return {_TUPLE_SENTINEL: [_to_jsonable(v) for v in o]}
260
+ if isinstance(o, list):
261
+ return [_to_jsonable(v) for v in o]
262
+ if isinstance(o, dict):
263
+ # dict subclasses (Counter, OrderedDict, defaultdict) first.
264
+ if isinstance(o, collections.Counter):
265
+ return {
266
+ _BUILTIN_SENTINEL: {
267
+ "type": "counter",
268
+ "items": [[_to_jsonable(k), v] for k, v in o.items()],
269
+ }
270
+ }
271
+ if isinstance(o, collections.OrderedDict):
272
+ return {
273
+ _BUILTIN_SENTINEL: {
274
+ "type": "ordereddict",
275
+ "items": [
276
+ [_to_jsonable(k), _to_jsonable(v)] for k, v in o.items()
277
+ ],
278
+ }
279
+ }
280
+ if isinstance(o, collections.defaultdict):
281
+ return {
282
+ _BUILTIN_SENTINEL: {
283
+ "type": "defaultdict",
284
+ "factory": _encode_factory(o.default_factory),
285
+ "items": [
286
+ [_to_jsonable(k), _to_jsonable(v)] for k, v in o.items()
287
+ ],
288
+ }
289
+ }
290
+ if all(isinstance(k, str) for k in o):
291
+ return {k: _to_jsonable(v) for k, v in o.items()}
292
+ return {
293
+ _DICT_SENTINEL: [
294
+ [_to_jsonable(k), _to_jsonable(v)] for k, v in o.items()
295
+ ]
296
+ }
297
+ builtin = _encode_builtin(o)
298
+ if builtin is not None:
299
+ return {_BUILTIN_SENTINEL: builtin}
300
+ state = _extract_object_state(o)
301
+ if state is not None:
302
+ return {
303
+ _OBJECT_SENTINEL: {
304
+ "class": type(o).__name__,
305
+ "state": {k: _to_jsonable(v) for k, v in state.items()},
306
+ }
307
+ }
308
+ # Last-resort: pickle. The task envelope is HMAC-signed end-to-end
309
+ # so unpickling on the worker is no more dangerous than the existing
310
+ # ``exec`` of reconstructed source.
311
+ try:
312
+ data = pickle.dumps(o)
313
+ except Exception as exc:
314
+ raise TypeError(
315
+ f"Object of type {type(o).__name__} is not serializable: {exc}"
316
+ ) from exc
317
+ return {_PICKLE_SENTINEL: base64.b64encode(data).decode("ascii")}
318
+
319
+
320
class _TaskEncoder(json.JSONEncoder):
    """JSON encoder that rewrites the whole tree before encoding it.

    The stock encoder serializes ``tuple`` / ``dict`` / ``list`` natively
    and never consults :meth:`default` for them, so sentinels could not
    be applied from there.  :meth:`iterencode` therefore converts the
    complete input with :func:`_to_jsonable` and hands the already
    JSON-safe result to the base class.
    """

    def iterencode(self, o: Any, _one_shot: bool = False) -> Any:
        prepared = _to_jsonable(o)
        return super().iterencode(prepared, _one_shot)

    def default(self, o: object) -> Any:  # pragma: no cover - unreachable
        # Kept as a safety net; the pre-walk above should leave nothing
        # for the base encoder to reject.
        return _to_jsonable(o)
334
+
335
+
336
+ def _reconstruct_object(info: dict[str, Any], namespace: dict[str, Any]) -> Any:
337
+ """Rebuild a single serialized object from its sentinel payload."""
338
+ cls = namespace.get(info["class"])
339
+ if cls is None:
340
+ return {_OBJECT_SENTINEL: info}
341
+ obj = cls.__new__(cls)
342
+ state = {k: _resolve(v, namespace) for k, v in info.get("state", {}).items()}
343
+ if hasattr(obj, "__dict__"):
344
+ obj.__dict__.update(state)
345
+ else:
346
+ for key, val in state.items():
347
+ object.__setattr__(obj, key, val)
348
+ return obj
349
+
350
+
351
+ def _resolve(value: Any, namespace: dict[str, Any]) -> Any:
352
+ """Recursively resolve serialized object sentinels using *namespace*."""
353
+ if isinstance(value, list):
354
+ return [_resolve(v, namespace) for v in value]
355
+ if not isinstance(value, dict):
356
+ return value
357
+ if len(value) == 1:
358
+ if _OBJECT_SENTINEL in value:
359
+ return _reconstruct_object(value[_OBJECT_SENTINEL], namespace)
360
+ if _BYTES_SENTINEL in value:
361
+ info = value[_BYTES_SENTINEL]
362
+ raw = base64.b64decode(info["data"])
363
+ kind = info.get("type")
364
+ if kind == "bytearray":
365
+ return bytearray(raw)
366
+ if kind == "memoryview":
367
+ return memoryview(raw)
368
+ return raw
369
+ if _BUILTIN_SENTINEL in value:
370
+ return _decode_builtin(value[_BUILTIN_SENTINEL], namespace)
371
+ if _TUPLE_SENTINEL in value:
372
+ return tuple(_resolve(v, namespace) for v in value[_TUPLE_SENTINEL])
373
+ if _DICT_SENTINEL in value:
374
+ return {
375
+ _resolve(k, namespace): _resolve(v, namespace)
376
+ for k, v in value[_DICT_SENTINEL]
377
+ }
378
+ if _PICKLE_SENTINEL in value:
379
+ return pickle.loads(base64.b64decode(value[_PICKLE_SENTINEL]))
380
+ return {k: _resolve(v, namespace) for k, v in value.items()}
381
+
382
+
383
def resolve_args(
    args: tuple[Any, ...],
    kwargs: dict[str, Any],
    namespace: dict[str, Any],
) -> tuple[tuple[Any, ...], dict[str, Any]]:
    """Rebuild the real call arguments from their serialized form.

    Each positional and keyword argument is passed through ``_resolve``
    with *namespace*, so class instances encoded as sentinels become
    objects again.  Intended to be called by the worker once the
    function's namespace has been reconstructed.

    Returns
    -------
    tuple
        ``(resolved_args, resolved_kwargs)``.
    """
    resolved_positional = tuple(_resolve(arg, namespace) for arg in args)
    resolved_keywords = {
        name: _resolve(val, namespace) for name, val in kwargs.items()
    }
    return resolved_positional, resolved_keywords
397
+
398
+
399
+ @dataclass(frozen=True)
400
+ class Task:
401
+ """A serializable envelope for remote function execution.
402
+
403
+ Bundles the serialized dependency graph with the target function
404
+ name and its arguments, so the consumer side needs zero knowledge
405
+ of offwork internals to dispatch work.
406
+
407
+ When a shared key is provided (via :meth:`to_signed_json`), the
408
+ serialized payload carries an HMAC-SHA256 signature that the worker
409
+ can verify before execution.
410
+ """
411
+
412
+ graph_json: str
413
+ function_name: str
414
+ args: tuple[Any, ...] = ()
415
+ kwargs: dict[str, Any] = field(default_factory=dict)
416
+ task_id: str = field(default_factory=lambda: uuid.uuid4().hex[:12])
417
+ timeout: float | None = None
418
+ retries: int = 0
419
+ retry_delay: float = 1.0
420
+ scheduled_at: float | None = None
421
+ recur_interval: float | None = None
422
+ schedule_id: str | None = None
423
+ throttle: float | None = None
424
+ signature: str | None = None
425
+
426
+ # -- Serialization -------------------------------------------------------
427
+
428
+ def _to_dict(self) -> dict[str, Any]:
429
+ """Build the core payload dict (without signature)."""
430
+ d: dict[str, Any] = {
431
+ "id": self.task_id,
432
+ "graph": self.graph_json,
433
+ "function": self.function_name,
434
+ "args": list(self.args),
435
+ "kwargs": self.kwargs,
436
+ }
437
+ if self.timeout is not None:
438
+ d["timeout"] = self.timeout
439
+ if self.retries:
440
+ d["retries"] = self.retries
441
+ if self.retry_delay != 1.0:
442
+ d["retry_delay"] = self.retry_delay
443
+ if self.scheduled_at is not None:
444
+ d["scheduled_at"] = self.scheduled_at
445
+ if self.recur_interval is not None:
446
+ d["recur_interval"] = self.recur_interval
447
+ if self.schedule_id is not None:
448
+ d["schedule_id"] = self.schedule_id
449
+ if self.throttle is not None:
450
+ d["throttle"] = self.throttle
451
+ return d
452
+
453
+ def to_json(self, *, signing_key: bytes | None = None) -> str:
454
+ """Serialize the task envelope to a JSON string.
455
+
456
+ Parameters
457
+ ----------
458
+ signing_key
459
+ When provided, the payload is HMAC-SHA256 signed and the
460
+ signature is included in the envelope. Workers that hold
461
+ the same key can verify it with :meth:`from_signed_json`.
462
+ """
463
+ d = self._to_dict()
464
+ if signing_key is not None:
465
+ payload = json.dumps(d, cls=_TaskEncoder, separators=(",", ":"), sort_keys=True)
466
+ d["signature"] = compute_signature(payload, signing_key)
467
+ return json.dumps(d, cls=_TaskEncoder)
468
+
469
+ @classmethod
470
+ def from_json(
471
+ cls,
472
+ json_str: str | bytes,
473
+ *,
474
+ signing_key: bytes | None = None,
475
+ ) -> Self:
476
+ """Deserialize a task envelope from a JSON string.
477
+
478
+ Parameters
479
+ ----------
480
+ signing_key
481
+ When provided **and** the envelope contains a ``signature``
482
+ field, the signature is verified. If verification fails,
483
+ :class:`~offwork.core.errors.SignatureError` is raised.
484
+ Unsigned tasks are accepted when *signing_key* is ``None``.
485
+
486
+ Raises
487
+ ------
488
+ SignatureError
489
+ If the signature is present but invalid, or if *signing_key*
490
+ is provided but the envelope has no signature.
491
+ """
492
+ data = json.loads(json_str)
493
+ sig = data.pop("signature", None)
494
+
495
+ if signing_key is not None:
496
+ if not sig:
497
+ raise SignatureError(
498
+ "Task is unsigned but signing is enabled — "
499
+ "rejecting unauthenticated task"
500
+ )
501
+ # Re-serialize without signature for verification
502
+ payload = json.dumps(data, cls=_TaskEncoder, separators=(",", ":"), sort_keys=True)
503
+ if not verify_signature(payload, sig, signing_key):
504
+ raise SignatureError("Task signature verification failed")
505
+
506
+ return cls(
507
+ graph_json=data["graph"],
508
+ function_name=data["function"],
509
+ args=tuple(data.get("args", ())),
510
+ kwargs=data.get("kwargs", {}),
511
+ task_id=data.get("id", uuid.uuid4().hex[:12]),
512
+ timeout=data.get("timeout"),
513
+ retries=data.get("retries", 0),
514
+ retry_delay=data.get("retry_delay", 1.0),
515
+ scheduled_at=data.get("scheduled_at"),
516
+ recur_interval=data.get("recur_interval"),
517
+ schedule_id=data.get("schedule_id"),
518
+ throttle=data.get("throttle"),
519
+ signature=sig or None, # normalise empty string to None
520
+ )
offwork/core/token.py ADDED
@@ -0,0 +1,184 @@
1
+ """Pre-shared token for automated task signing.
2
+
3
+ Tokens provide an alternative to the interactive PIN-based pairing
4
+ protocol (see :mod:`offwork.core.pairing`). A token is a random
5
+ 32-byte secret that can be generated offline, stored in CI secrets or
6
+ configuration management, and distributed to clients and workers
7
+ independently — no real-time pairing step is required.
8
+
9
+ Once both sides share the same token, the signing and verification
10
+ flow is identical to the pairing-based approach: the client signs
11
+ every task with HMAC-SHA256 and the worker verifies the signature
12
+ before execution.
13
+
14
+ Key resolution order (highest priority first):
15
+
16
+ 1. ``OFFWORK_SIGNING_TOKEN`` environment variable (hex-encoded)
17
+ 2. ``~/.offwork/token`` file (hex-encoded)
18
+ 3. ``~/.offwork/{client,worker}.key`` file (raw bytes, from pairing)
19
+
20
+ All primitives are stdlib-only.
21
+ """
22
+
23
+ import os
24
+ import logging
25
+ from pathlib import Path
26
+
27
+ from offwork.core.signing import derive_key
28
+
29
logger = logging.getLogger(__name__)

# Name of the environment variable that is consulted first when a token
# is looked up.
_TOKEN_ENV_VAR = "OFFWORK_SIGNING_TOKEN"

# File persistence: default key directory and the token file name
# inside it.
_DEFAULT_KEY_DIR = Path.home() / ".offwork"
_TOKEN_FILE = "token"

# Token size in bytes; the hex encoding is twice this many characters.
_TOKEN_BYTES = 32
40
+
41
+
42
def generate_token() -> str:
    """Create a fresh signing token and return its hex encoding.

    ``_TOKEN_BYTES`` bytes are drawn from ``os.urandom``; the result is a
    lowercase hex string twice that long.
    """
    secret = os.urandom(_TOKEN_BYTES)
    return secret.hex()
49
+
50
+
51
def save_token(
    token_hex: str,
    key_dir: Path | None = None,
) -> Path:
    """Persist a hex-encoded token to ``~/.offwork/token``.

    The file is created with owner-only permissions from the start, so
    the secret is never on disk under the process's default umask.

    Parameters
    ----------
    token_hex
        The 64-character hex-encoded token string.
    key_dir
        Override the default ``~/.offwork`` directory.

    Returns
    -------
    Path
        The file that was written.

    Raises
    ------
    ValueError
        If *token_hex* is not a valid 64-character hex string.
    """
    _validate_token_hex(token_hex)
    d = _ensure_key_dir(key_dir)
    path = d / _TOKEN_FILE
    # The mode passed to os.open only applies when the file is created.
    fd = os.open(path, os.O_WRONLY | os.O_CREAT | os.O_TRUNC, 0o600)
    with os.fdopen(fd, "w", encoding="ascii") as handle:
        # A pre-existing file keeps its old mode, so tighten it before
        # the secret is written.
        path.chmod(0o600)
        handle.write(token_hex + "\n")
    logger.info("Saved token to %s", path)
    return path
81
+
82
+
83
def load_token(key_dir: Path | None = None) -> str | None:
    """Load a hex-encoded token from the environment or disk.

    Resolution order:

    1. ``OFFWORK_SIGNING_TOKEN`` environment variable
    2. ``~/.offwork/token`` file (or ``key_dir / "token"``)

    If the environment variable is set but invalid, a warning is logged
    and ``None`` is returned; the file is *not* consulted in that case.

    This is a read-only lookup: the key directory is not created when it
    is missing, so it also works on a read-only home directory.

    Returns ``None`` when no valid token is found.
    """
    # 1. Environment variable
    env_val = os.environ.get(_TOKEN_ENV_VAR)
    if env_val is not None:
        env_val = env_val.strip()
        if _is_valid_token_hex(env_val):
            return env_val
        logger.warning(
            "%s is set but contains an invalid token "
            "(expected 64 hex characters)",
            _TOKEN_ENV_VAR,
        )
        return None

    # 2. File on disk.  Read first and handle absence afterwards, so the
    # file cannot vanish between an existence check and the read.
    path = (key_dir or _DEFAULT_KEY_DIR) / _TOKEN_FILE
    try:
        content = path.read_text().strip()
    except FileNotFoundError:
        return None
    if _is_valid_token_hex(content):
        return content
    logger.warning("Invalid token file %s (expected 64 hex characters)", path)
    return None
116
+
117
+
118
def clear_token(key_dir: Path | None = None) -> bool:
    """Delete a saved token file. Returns ``True`` if a file was removed.

    A missing file (or a missing key directory) is not an error and
    yields ``False``.  The key directory is never created by this call.
    """
    path = (key_dir or _DEFAULT_KEY_DIR) / _TOKEN_FILE
    # Unlink and handle absence afterwards, so a concurrent removal
    # cannot turn into an unhandled FileNotFoundError.
    try:
        path.unlink()
    except FileNotFoundError:
        return False
    logger.info("Removed token %s", path)
    return True
127
+
128
+
129
def resolve_signing_key(role: str, key_dir: Path | None = None) -> bytes | None:
    """Pick the signing key material and derive the key from it.

    Sources are tried in this order and the first hit wins:

    1. ``OFFWORK_SIGNING_TOKEN`` environment variable
    2. ``~/.offwork/token`` file
    3. ``~/.offwork/{role}.key`` (from pairing)

    The first two are handled by :func:`load_token`, the third by
    ``load_shared_key``.  Whatever is found is passed through
    ``derive_key``; ``None`` is returned when no source yields anything.
    """
    # Imported here rather than at module level (presumably to avoid an
    # import cycle with the pairing module -- confirm).
    from offwork.core.pairing import load_shared_key

    token_hex = load_token(key_dir)
    if token_hex is not None:
        return derive_key(bytes.fromhex(token_hex))

    pairing_key = load_shared_key(role, key_dir)
    return None if pairing_key is None else derive_key(pairing_key)
155
+
156
+
157
+ # -- Helpers ----------------------------------------------------------------
158
+
159
+
160
+ def _ensure_key_dir(key_dir: Path | None = None) -> Path:
161
+ """Return the key directory, creating it if necessary."""
162
+ d = key_dir or _DEFAULT_KEY_DIR
163
+ d.mkdir(parents=True, exist_ok=True)
164
+ return d
165
+
166
+
167
def _is_valid_token_hex(s: str) -> bool:
    """Return ``True`` if *s* is exactly 64 hex digits encoding 32 bytes.

    ``bytes.fromhex`` skips ASCII whitespace, so a 64-character string
    with embedded whitespace would decode without error into fewer than
    ``_TOKEN_BYTES`` bytes.  The decoded length is therefore checked as
    well, which rejects such weakened tokens.
    """
    if len(s) != _TOKEN_BYTES * 2:
        return False
    try:
        decoded = bytes.fromhex(s)
    except ValueError:
        return False
    return len(decoded) == _TOKEN_BYTES
176
+
177
+
178
def _validate_token_hex(token_hex: str) -> None:
    """Check *token_hex* with ``_is_valid_token_hex``.

    Raises
    ------
    ValueError
        If the check fails; the message reports the length received.
    """
    if _is_valid_token_hex(token_hex):
        return
    raise ValueError(
        f"Invalid token: expected a 64-character hex string, "
        f"got {len(token_hex)} characters"
    )