hatchet-sdk 0.41.0__py3-none-any.whl → 0.42.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of hatchet-sdk might be problematic. Click here for more details.
- hatchet_sdk/clients/admin.py +10 -8
- hatchet_sdk/clients/dispatcher/action_listener.py +1 -1
- hatchet_sdk/clients/dispatcher/dispatcher.py +2 -2
- hatchet_sdk/clients/rest/tenacity_utils.py +6 -1
- hatchet_sdk/context/context.py +83 -46
- hatchet_sdk/context/worker_context.py +1 -1
- hatchet_sdk/contracts/dispatcher_pb2.py +71 -67
- hatchet_sdk/contracts/dispatcher_pb2.pyi +29 -2
- hatchet_sdk/contracts/workflows_pb2.py +42 -40
- hatchet_sdk/contracts/workflows_pb2.pyi +22 -6
- hatchet_sdk/hatchet.py +44 -32
- hatchet_sdk/utils/backoff.py +1 -1
- hatchet_sdk/utils/serialization.py +4 -1
- hatchet_sdk/utils/tracing.py +7 -4
- hatchet_sdk/utils/types.py +8 -0
- hatchet_sdk/utils/typing.py +9 -0
- hatchet_sdk/v2/callable.py +1 -0
- hatchet_sdk/worker/action_listener_process.py +7 -9
- hatchet_sdk/worker/runner/run_loop_manager.py +15 -9
- hatchet_sdk/worker/runner/runner.py +57 -36
- hatchet_sdk/worker/worker.py +96 -59
- hatchet_sdk/workflow.py +84 -26
- {hatchet_sdk-0.41.0.dist-info → hatchet_sdk-0.42.1.dist-info}/METADATA +1 -1
- {hatchet_sdk-0.41.0.dist-info → hatchet_sdk-0.42.1.dist-info}/RECORD +26 -24
- {hatchet_sdk-0.41.0.dist-info → hatchet_sdk-0.42.1.dist-info}/entry_points.txt +2 -0
- {hatchet_sdk-0.41.0.dist-info → hatchet_sdk-0.42.1.dist-info}/WHEEL +0 -0
|
@@ -8,9 +8,10 @@ from concurrent.futures import ThreadPoolExecutor
|
|
|
8
8
|
from enum import Enum
|
|
9
9
|
from multiprocessing import Queue
|
|
10
10
|
from threading import Thread, current_thread
|
|
11
|
-
from typing import Any, Callable, Dict
|
|
11
|
+
from typing import Any, Callable, Dict, Literal, Type, TypeVar, cast, overload
|
|
12
12
|
|
|
13
13
|
from opentelemetry.trace import StatusCode
|
|
14
|
+
from pydantic import BaseModel
|
|
14
15
|
|
|
15
16
|
from hatchet_sdk.client import new_client_raw
|
|
16
17
|
from hatchet_sdk.clients.admin import new_admin
|
|
@@ -18,9 +19,9 @@ from hatchet_sdk.clients.dispatcher.action_listener import Action
|
|
|
18
19
|
from hatchet_sdk.clients.dispatcher.dispatcher import new_dispatcher
|
|
19
20
|
from hatchet_sdk.clients.run_event_listener import new_listener
|
|
20
21
|
from hatchet_sdk.clients.workflow_listener import PooledWorkflowRunListener
|
|
21
|
-
from hatchet_sdk.context import Context
|
|
22
|
+
from hatchet_sdk.context import Context # type: ignore[attr-defined]
|
|
22
23
|
from hatchet_sdk.context.worker_context import WorkerContext
|
|
23
|
-
from hatchet_sdk.contracts.dispatcher_pb2 import (
|
|
24
|
+
from hatchet_sdk.contracts.dispatcher_pb2 import ( # type: ignore[attr-defined]
|
|
24
25
|
GROUP_KEY_EVENT_TYPE_COMPLETED,
|
|
25
26
|
GROUP_KEY_EVENT_TYPE_FAILED,
|
|
26
27
|
GROUP_KEY_EVENT_TYPE_STARTED,
|
|
@@ -32,6 +33,7 @@ from hatchet_sdk.contracts.dispatcher_pb2 import (
|
|
|
32
33
|
from hatchet_sdk.loader import ClientConfig
|
|
33
34
|
from hatchet_sdk.logger import logger
|
|
34
35
|
from hatchet_sdk.utils.tracing import create_tracer, parse_carrier_from_metadata
|
|
36
|
+
from hatchet_sdk.utils.types import WorkflowValidator
|
|
35
37
|
from hatchet_sdk.v2.callable import DurableContext
|
|
36
38
|
from hatchet_sdk.worker.action_listener_process import ActionEvent
|
|
37
39
|
from hatchet_sdk.worker.runner.utils.capture_logs import copy_context_vars, sr, wr
|
|
@@ -48,11 +50,12 @@ class Runner:
|
|
|
48
50
|
def __init__(
|
|
49
51
|
self,
|
|
50
52
|
name: str,
|
|
51
|
-
event_queue: Queue,
|
|
53
|
+
event_queue: "Queue[Any]",
|
|
52
54
|
max_runs: int | None = None,
|
|
53
55
|
handle_kill: bool = True,
|
|
54
56
|
action_registry: dict[str, Callable[..., Any]] = {},
|
|
55
|
-
|
|
57
|
+
validator_registry: dict[str, WorkflowValidator] = {},
|
|
58
|
+
config: ClientConfig = ClientConfig(),
|
|
56
59
|
labels: dict[str, str | int] = {},
|
|
57
60
|
):
|
|
58
61
|
# We store the config so we can dynamically create clients for the dispatcher client.
|
|
@@ -60,9 +63,10 @@ class Runner:
|
|
|
60
63
|
self.client = new_client_raw(config)
|
|
61
64
|
self.name = self.client.config.namespace + name
|
|
62
65
|
self.max_runs = max_runs
|
|
63
|
-
self.tasks:
|
|
64
|
-
self.contexts:
|
|
66
|
+
self.tasks: dict[str, asyncio.Task[Any]] = {} # Store run ids and futures
|
|
67
|
+
self.contexts: dict[str, Context] = {} # Store run ids and contexts
|
|
65
68
|
self.action_registry: dict[str, Callable[..., Any]] = action_registry
|
|
69
|
+
self.validator_registry = validator_registry
|
|
66
70
|
|
|
67
71
|
self.event_queue = event_queue
|
|
68
72
|
|
|
@@ -89,7 +93,7 @@ class Runner:
|
|
|
89
93
|
def create_workflow_run_url(self, action: Action) -> str:
|
|
90
94
|
return f"{self.config.server_url}/workflow-runs/{action.workflow_run_id}?tenant={action.tenant_id}"
|
|
91
95
|
|
|
92
|
-
def run(self, action: Action):
|
|
96
|
+
def run(self, action: Action) -> None:
|
|
93
97
|
ctx = parse_carrier_from_metadata(action.additional_metadata)
|
|
94
98
|
|
|
95
99
|
with self.otel_tracer.start_as_current_span(
|
|
@@ -122,8 +126,8 @@ class Runner:
|
|
|
122
126
|
span.add_event(log)
|
|
123
127
|
logger.error(log)
|
|
124
128
|
|
|
125
|
-
def step_run_callback(self, action: Action):
|
|
126
|
-
def inner_callback(task: asyncio.Task):
|
|
129
|
+
def step_run_callback(self, action: Action) -> Callable[[asyncio.Task[Any]], None]:
|
|
130
|
+
def inner_callback(task: asyncio.Task[Any]) -> None:
|
|
127
131
|
self.cleanup_run_id(action.step_run_id)
|
|
128
132
|
|
|
129
133
|
errored = False
|
|
@@ -164,8 +168,10 @@ class Runner:
|
|
|
164
168
|
|
|
165
169
|
return inner_callback
|
|
166
170
|
|
|
167
|
-
def group_key_run_callback(
|
|
168
|
-
|
|
171
|
+
def group_key_run_callback(
|
|
172
|
+
self, action: Action
|
|
173
|
+
) -> Callable[[asyncio.Task[Any]], None]:
|
|
174
|
+
def inner_callback(task: asyncio.Task[Any]) -> None:
|
|
169
175
|
self.cleanup_run_id(action.get_group_key_run_id)
|
|
170
176
|
|
|
171
177
|
errored = False
|
|
@@ -204,7 +210,10 @@ class Runner:
|
|
|
204
210
|
|
|
205
211
|
return inner_callback
|
|
206
212
|
|
|
207
|
-
|
|
213
|
+
## TODO: Stricter type hinting here
|
|
214
|
+
def thread_action_func(
|
|
215
|
+
self, context: Context, action_func: Callable[..., Any], action: Action
|
|
216
|
+
) -> Any:
|
|
208
217
|
if action.step_run_id is not None and action.step_run_id != "":
|
|
209
218
|
self.threads[action.step_run_id] = current_thread()
|
|
210
219
|
elif (
|
|
@@ -215,10 +224,15 @@ class Runner:
|
|
|
215
224
|
|
|
216
225
|
return action_func(context)
|
|
217
226
|
|
|
227
|
+
## TODO: Stricter type hinting here
|
|
218
228
|
# We wrap all actions in an async func
|
|
219
229
|
async def async_wrapped_action_func(
|
|
220
|
-
self,
|
|
221
|
-
|
|
230
|
+
self,
|
|
231
|
+
context: Context,
|
|
232
|
+
action_func: Callable[..., Any],
|
|
233
|
+
action: Action,
|
|
234
|
+
run_id: str,
|
|
235
|
+
) -> Any:
|
|
222
236
|
wr.set(context.workflow_run_id())
|
|
223
237
|
sr.set(context.step_run_id)
|
|
224
238
|
|
|
@@ -240,9 +254,7 @@ class Runner:
|
|
|
240
254
|
)
|
|
241
255
|
|
|
242
256
|
loop = asyncio.get_event_loop()
|
|
243
|
-
|
|
244
|
-
|
|
245
|
-
return res
|
|
257
|
+
return await loop.run_in_executor(self.thread_pool, pfunc)
|
|
246
258
|
except Exception as e:
|
|
247
259
|
logger.error(
|
|
248
260
|
errorWithTraceback(
|
|
@@ -254,7 +266,7 @@ class Runner:
|
|
|
254
266
|
finally:
|
|
255
267
|
self.cleanup_run_id(run_id)
|
|
256
268
|
|
|
257
|
-
def cleanup_run_id(self, run_id: str):
|
|
269
|
+
def cleanup_run_id(self, run_id: str | None) -> None:
|
|
258
270
|
if run_id in self.tasks:
|
|
259
271
|
del self.tasks[run_id]
|
|
260
272
|
|
|
@@ -267,7 +279,7 @@ class Runner:
|
|
|
267
279
|
def create_context(
|
|
268
280
|
self, action: Action, action_func: Callable[..., Any] | None
|
|
269
281
|
) -> Context | DurableContext:
|
|
270
|
-
if hasattr(action_func, "durable") and action_func
|
|
282
|
+
if hasattr(action_func, "durable") and getattr(action_func, "durable"):
|
|
271
283
|
return DurableContext(
|
|
272
284
|
action,
|
|
273
285
|
self.dispatcher_client,
|
|
@@ -278,6 +290,7 @@ class Runner:
|
|
|
278
290
|
self.workflow_run_event_listener,
|
|
279
291
|
self.worker_context,
|
|
280
292
|
self.client.config.namespace,
|
|
293
|
+
validator_registry=self.validator_registry,
|
|
281
294
|
)
|
|
282
295
|
|
|
283
296
|
return Context(
|
|
@@ -290,9 +303,10 @@ class Runner:
|
|
|
290
303
|
self.workflow_run_event_listener,
|
|
291
304
|
self.worker_context,
|
|
292
305
|
self.client.config.namespace,
|
|
306
|
+
validator_registry=self.validator_registry,
|
|
293
307
|
)
|
|
294
308
|
|
|
295
|
-
async def handle_start_step_run(self, action: Action):
|
|
309
|
+
async def handle_start_step_run(self, action: Action) -> None:
|
|
296
310
|
with self.otel_tracer.start_as_current_span(
|
|
297
311
|
f"hatchet.worker.handle_start_step_run.{action.step_id}",
|
|
298
312
|
) as span:
|
|
@@ -336,7 +350,7 @@ class Runner:
|
|
|
336
350
|
|
|
337
351
|
span.add_event("Finished step run")
|
|
338
352
|
|
|
339
|
-
async def handle_start_group_key_run(self, action: Action):
|
|
353
|
+
async def handle_start_group_key_run(self, action: Action) -> None:
|
|
340
354
|
with self.otel_tracer.start_as_current_span(
|
|
341
355
|
f"hatchet.worker.handle_start_step_run.{action.step_id}"
|
|
342
356
|
) as span:
|
|
@@ -353,6 +367,7 @@ class Runner:
|
|
|
353
367
|
self.worker_context,
|
|
354
368
|
self.client.config.namespace,
|
|
355
369
|
)
|
|
370
|
+
|
|
356
371
|
self.contexts[action.get_group_key_run_id] = context
|
|
357
372
|
|
|
358
373
|
# Find the corresponding action function from the registry
|
|
@@ -387,18 +402,18 @@ class Runner:
|
|
|
387
402
|
|
|
388
403
|
span.add_event("Finished group key run")
|
|
389
404
|
|
|
390
|
-
def force_kill_thread(self, thread):
|
|
405
|
+
def force_kill_thread(self, thread: Thread) -> None:
|
|
391
406
|
"""Terminate a python threading.Thread."""
|
|
392
407
|
try:
|
|
393
408
|
if not thread.is_alive():
|
|
394
409
|
return
|
|
395
410
|
|
|
396
|
-
|
|
411
|
+
ident = cast(int, thread.ident)
|
|
412
|
+
|
|
413
|
+
logger.info(f"Forcefully terminating thread {ident}")
|
|
397
414
|
|
|
398
415
|
exc = ctypes.py_object(SystemExit)
|
|
399
|
-
res = ctypes.pythonapi.PyThreadState_SetAsyncExc(
|
|
400
|
-
ctypes.c_long(thread.ident), exc
|
|
401
|
-
)
|
|
416
|
+
res = ctypes.pythonapi.PyThreadState_SetAsyncExc(ctypes.c_long(ident), exc)
|
|
402
417
|
if res == 0:
|
|
403
418
|
raise ValueError("Invalid thread ID")
|
|
404
419
|
elif res != 1:
|
|
@@ -408,7 +423,7 @@ class Runner:
|
|
|
408
423
|
ctypes.pythonapi.PyThreadState_SetAsyncExc(thread.ident, 0)
|
|
409
424
|
raise SystemError("PyThreadState_SetAsyncExc failed")
|
|
410
425
|
|
|
411
|
-
logger.info(f"Successfully terminated thread {
|
|
426
|
+
logger.info(f"Successfully terminated thread {ident}")
|
|
412
427
|
|
|
413
428
|
# Immediately add a new thread to the thread pool, because we've actually killed a worker
|
|
414
429
|
# in the ThreadPoolExecutor
|
|
@@ -416,7 +431,7 @@ class Runner:
|
|
|
416
431
|
except Exception as e:
|
|
417
432
|
logger.exception(f"Failed to terminate thread: {e}")
|
|
418
433
|
|
|
419
|
-
async def handle_cancel_action(self, run_id: str):
|
|
434
|
+
async def handle_cancel_action(self, run_id: str) -> None:
|
|
420
435
|
with self.otel_tracer.start_as_current_span(
|
|
421
436
|
"hatchet.worker.handle_cancel_action"
|
|
422
437
|
) as span:
|
|
@@ -427,7 +442,9 @@ class Runner:
|
|
|
427
442
|
# call cancel to signal the context to stop
|
|
428
443
|
if run_id in self.contexts:
|
|
429
444
|
context = self.contexts.get(run_id)
|
|
430
|
-
|
|
445
|
+
|
|
446
|
+
if context:
|
|
447
|
+
context.cancel()
|
|
431
448
|
|
|
432
449
|
await asyncio.sleep(1)
|
|
433
450
|
|
|
@@ -449,16 +466,20 @@ class Runner:
|
|
|
449
466
|
span.add_event(f"Finished cancelling run id: {run_id}")
|
|
450
467
|
|
|
451
468
|
def serialize_output(self, output: Any) -> str:
|
|
452
|
-
|
|
469
|
+
|
|
470
|
+
if isinstance(output, BaseModel):
|
|
471
|
+
return output.model_dump_json()
|
|
472
|
+
|
|
453
473
|
if output is not None:
|
|
454
474
|
try:
|
|
455
|
-
|
|
475
|
+
return json.dumps(output)
|
|
456
476
|
except Exception as e:
|
|
457
477
|
logger.error(f"Could not serialize output: {e}")
|
|
458
|
-
|
|
459
|
-
|
|
478
|
+
return str(output)
|
|
479
|
+
|
|
480
|
+
return ""
|
|
460
481
|
|
|
461
|
-
async def wait_for_tasks(self):
|
|
482
|
+
async def wait_for_tasks(self) -> None:
|
|
462
483
|
running = len(self.tasks.keys())
|
|
463
484
|
while running > 0:
|
|
464
485
|
logger.info(f"waiting for {running} tasks to finish...")
|
|
@@ -466,6 +487,6 @@ class Runner:
|
|
|
466
487
|
running = len(self.tasks.keys())
|
|
467
488
|
|
|
468
489
|
|
|
469
|
-
def errorWithTraceback(message: str, e: Exception):
|
|
490
|
+
def errorWithTraceback(message: str, e: Exception) -> str:
|
|
470
491
|
trace = "".join(traceback.format_exception(type(e), e, e.__traceback__))
|
|
471
492
|
return f"{message}\n{trace}"
|
hatchet_sdk/worker/worker.py
CHANGED
|
@@ -1,22 +1,32 @@
|
|
|
1
1
|
import asyncio
|
|
2
2
|
import multiprocessing
|
|
3
|
+
import multiprocessing.context
|
|
3
4
|
import os
|
|
4
5
|
import signal
|
|
5
6
|
import sys
|
|
7
|
+
from concurrent.futures import Future
|
|
6
8
|
from dataclasses import dataclass, field
|
|
7
9
|
from enum import Enum
|
|
8
|
-
from multiprocessing import
|
|
9
|
-
from
|
|
10
|
+
from multiprocessing import Queue
|
|
11
|
+
from multiprocessing.process import BaseProcess
|
|
12
|
+
from types import FrameType
|
|
13
|
+
from typing import Any, Callable, TypeVar, get_type_hints
|
|
10
14
|
|
|
15
|
+
from hatchet_sdk import Context
|
|
11
16
|
from hatchet_sdk.client import Client, new_client_raw
|
|
12
|
-
from hatchet_sdk.
|
|
13
|
-
|
|
17
|
+
from hatchet_sdk.contracts.workflows_pb2 import ( # type: ignore[attr-defined]
|
|
18
|
+
CreateWorkflowVersionOpts,
|
|
19
|
+
)
|
|
14
20
|
from hatchet_sdk.loader import ClientConfig
|
|
15
21
|
from hatchet_sdk.logger import logger
|
|
22
|
+
from hatchet_sdk.utils.types import WorkflowValidator
|
|
16
23
|
from hatchet_sdk.v2.callable import HatchetCallable
|
|
24
|
+
from hatchet_sdk.v2.concurrency import ConcurrencyFunction
|
|
17
25
|
from hatchet_sdk.worker.action_listener_process import worker_action_listener_process
|
|
18
26
|
from hatchet_sdk.worker.runner.run_loop_manager import WorkerActionRunLoopManager
|
|
19
|
-
from hatchet_sdk.workflow import
|
|
27
|
+
from hatchet_sdk.workflow import WorkflowInterface
|
|
28
|
+
|
|
29
|
+
T = TypeVar("T")
|
|
20
30
|
|
|
21
31
|
|
|
22
32
|
class WorkerStatus(Enum):
|
|
@@ -28,46 +38,60 @@ class WorkerStatus(Enum):
|
|
|
28
38
|
|
|
29
39
|
@dataclass
|
|
30
40
|
class WorkerStartOptions:
|
|
31
|
-
loop: asyncio.AbstractEventLoop = field(default=None)
|
|
41
|
+
loop: asyncio.AbstractEventLoop | None = field(default=None)
|
|
32
42
|
|
|
33
43
|
|
|
34
|
-
@dataclass
|
|
35
44
|
class Worker:
|
|
36
|
-
|
|
37
|
-
|
|
38
|
-
|
|
39
|
-
|
|
40
|
-
|
|
41
|
-
|
|
42
|
-
|
|
43
|
-
|
|
44
|
-
|
|
45
|
-
|
|
46
|
-
|
|
47
|
-
|
|
48
|
-
|
|
49
|
-
|
|
50
|
-
|
|
51
|
-
|
|
52
|
-
|
|
53
|
-
|
|
54
|
-
|
|
55
|
-
|
|
56
|
-
|
|
57
|
-
|
|
58
|
-
|
|
59
|
-
|
|
60
|
-
|
|
61
|
-
|
|
45
|
+
def __init__(
|
|
46
|
+
self,
|
|
47
|
+
name: str,
|
|
48
|
+
config: ClientConfig = ClientConfig(),
|
|
49
|
+
max_runs: int | None = None,
|
|
50
|
+
labels: dict[str, str | int] = {},
|
|
51
|
+
debug: bool = False,
|
|
52
|
+
owned_loop: bool = True,
|
|
53
|
+
handle_kill: bool = True,
|
|
54
|
+
) -> None:
|
|
55
|
+
self.name = name
|
|
56
|
+
self.config = config
|
|
57
|
+
self.max_runs = max_runs
|
|
58
|
+
self.debug = debug
|
|
59
|
+
self.labels = labels
|
|
60
|
+
self.handle_kill = handle_kill
|
|
61
|
+
self.owned_loop = owned_loop
|
|
62
|
+
|
|
63
|
+
self.client: Client
|
|
64
|
+
|
|
65
|
+
self.action_registry: dict[str, Callable[[Context], T]] = {}
|
|
66
|
+
self.validator_registry: dict[str, WorkflowValidator] = {}
|
|
67
|
+
|
|
68
|
+
self.killing: bool = False
|
|
69
|
+
self._status: WorkerStatus
|
|
70
|
+
|
|
71
|
+
self.action_listener_process: BaseProcess
|
|
72
|
+
self.action_listener_health_check: asyncio.Task[Any]
|
|
73
|
+
self.action_runner: WorkerActionRunLoopManager
|
|
74
|
+
|
|
75
|
+
self.ctx = multiprocessing.get_context("spawn")
|
|
76
|
+
|
|
77
|
+
self.action_queue: "Queue[Any]" = self.ctx.Queue()
|
|
78
|
+
self.event_queue: "Queue[Any]" = self.ctx.Queue()
|
|
79
|
+
|
|
80
|
+
self.loop: asyncio.AbstractEventLoop
|
|
81
|
+
|
|
62
82
|
self.client = new_client_raw(self.config, self.debug)
|
|
63
83
|
self.name = self.client.config.namespace + self.name
|
|
64
|
-
if self.owned_loop:
|
|
65
|
-
self._setup_signal_handlers()
|
|
66
84
|
|
|
67
|
-
|
|
85
|
+
self._setup_signal_handlers()
|
|
86
|
+
|
|
87
|
+
def register_function(
|
|
88
|
+
self, action: str, func: HatchetCallable[Any] | ConcurrencyFunction
|
|
89
|
+
) -> None:
|
|
68
90
|
self.action_registry[action] = func
|
|
69
91
|
|
|
70
|
-
def register_workflow_from_opts(
|
|
92
|
+
def register_workflow_from_opts(
|
|
93
|
+
self, name: str, opts: CreateWorkflowVersionOpts
|
|
94
|
+
) -> None:
|
|
71
95
|
try:
|
|
72
96
|
self.client.admin.put_workflow(opts.name, opts)
|
|
73
97
|
except Exception as e:
|
|
@@ -75,7 +99,7 @@ class Worker:
|
|
|
75
99
|
logger.error(e)
|
|
76
100
|
sys.exit(1)
|
|
77
101
|
|
|
78
|
-
def register_workflow(self, workflow:
|
|
102
|
+
def register_workflow(self, workflow: WorkflowInterface) -> None:
|
|
79
103
|
namespace = self.client.config.namespace
|
|
80
104
|
|
|
81
105
|
try:
|
|
@@ -87,24 +111,30 @@ class Worker:
|
|
|
87
111
|
logger.error(e)
|
|
88
112
|
sys.exit(1)
|
|
89
113
|
|
|
90
|
-
def create_action_function(
|
|
91
|
-
|
|
114
|
+
def create_action_function(
|
|
115
|
+
action_func: Callable[..., T]
|
|
116
|
+
) -> Callable[[Context], T]:
|
|
117
|
+
def action_function(context: Context) -> T:
|
|
92
118
|
return action_func(workflow, context)
|
|
93
119
|
|
|
94
120
|
if asyncio.iscoroutinefunction(action_func):
|
|
95
|
-
action_function
|
|
121
|
+
setattr(action_function, "is_coroutine", True)
|
|
96
122
|
else:
|
|
97
|
-
action_function
|
|
123
|
+
setattr(action_function, "is_coroutine", False)
|
|
98
124
|
|
|
99
125
|
return action_function
|
|
100
126
|
|
|
101
127
|
for action_name, action_func in workflow.get_actions(namespace):
|
|
102
128
|
self.action_registry[action_name] = create_action_function(action_func)
|
|
129
|
+
return_type = get_type_hints(action_func).get("return")
|
|
130
|
+
self.validator_registry[action_name] = WorkflowValidator(
|
|
131
|
+
workflow_input=workflow.input_validator, step_output=return_type
|
|
132
|
+
)
|
|
103
133
|
|
|
104
134
|
def status(self) -> WorkerStatus:
|
|
105
135
|
return self._status
|
|
106
136
|
|
|
107
|
-
def setup_loop(self, loop: asyncio.AbstractEventLoop = None):
|
|
137
|
+
def setup_loop(self, loop: asyncio.AbstractEventLoop | None = None) -> bool:
|
|
108
138
|
try:
|
|
109
139
|
loop = loop or asyncio.get_running_loop()
|
|
110
140
|
self.loop = loop
|
|
@@ -118,17 +148,22 @@ class Worker:
|
|
|
118
148
|
created_loop = True
|
|
119
149
|
return created_loop
|
|
120
150
|
|
|
121
|
-
def start(
|
|
151
|
+
def start(
|
|
152
|
+
self, options: WorkerStartOptions = WorkerStartOptions()
|
|
153
|
+
) -> Future[asyncio.Task[Any] | None]:
|
|
122
154
|
self.owned_loop = self.setup_loop(options.loop)
|
|
155
|
+
|
|
123
156
|
f = asyncio.run_coroutine_threadsafe(
|
|
124
157
|
self.async_start(options, _from_start=True), self.loop
|
|
125
158
|
)
|
|
159
|
+
|
|
126
160
|
# start the loop and wait until its closed
|
|
127
161
|
if self.owned_loop:
|
|
128
162
|
self.loop.run_forever()
|
|
129
163
|
|
|
130
164
|
if self.handle_kill:
|
|
131
165
|
sys.exit(0)
|
|
166
|
+
|
|
132
167
|
return f
|
|
133
168
|
|
|
134
169
|
## Start methods
|
|
@@ -136,7 +171,7 @@ class Worker:
|
|
|
136
171
|
self,
|
|
137
172
|
options: WorkerStartOptions = WorkerStartOptions(),
|
|
138
173
|
_from_start: bool = False,
|
|
139
|
-
):
|
|
174
|
+
) -> Any | None:
|
|
140
175
|
main_pid = os.getpid()
|
|
141
176
|
logger.info("------------------------------------------")
|
|
142
177
|
logger.info("STARTING HATCHET...")
|
|
@@ -148,25 +183,28 @@ class Worker:
|
|
|
148
183
|
logger.error(
|
|
149
184
|
"no actions registered, register workflows or actions before starting worker"
|
|
150
185
|
)
|
|
151
|
-
return
|
|
186
|
+
return None
|
|
152
187
|
|
|
153
188
|
# non blocking setup
|
|
154
189
|
if not _from_start:
|
|
155
190
|
self.setup_loop(options.loop)
|
|
156
191
|
|
|
157
192
|
self.action_listener_process = self._start_listener()
|
|
193
|
+
|
|
158
194
|
self.action_runner = self._run_action_runner()
|
|
195
|
+
|
|
159
196
|
self.action_listener_health_check = self.loop.create_task(
|
|
160
197
|
self._check_listener_health()
|
|
161
198
|
)
|
|
162
199
|
|
|
163
200
|
return await self.action_listener_health_check
|
|
164
201
|
|
|
165
|
-
def _run_action_runner(self):
|
|
202
|
+
def _run_action_runner(self) -> WorkerActionRunLoopManager:
|
|
166
203
|
# Retrieve the shared queue
|
|
167
|
-
|
|
204
|
+
return WorkerActionRunLoopManager(
|
|
168
205
|
self.name,
|
|
169
206
|
self.action_registry,
|
|
207
|
+
self.validator_registry,
|
|
170
208
|
self.max_runs,
|
|
171
209
|
self.config,
|
|
172
210
|
self.action_queue,
|
|
@@ -177,10 +215,9 @@ class Worker:
|
|
|
177
215
|
self.labels,
|
|
178
216
|
)
|
|
179
217
|
|
|
180
|
-
|
|
181
|
-
|
|
182
|
-
def _start_listener(self):
|
|
218
|
+
def _start_listener(self) -> multiprocessing.context.SpawnProcess:
|
|
183
219
|
action_list = [str(key) for key in self.action_registry.keys()]
|
|
220
|
+
|
|
184
221
|
try:
|
|
185
222
|
process = self.ctx.Process(
|
|
186
223
|
target=worker_action_listener_process,
|
|
@@ -204,7 +241,7 @@ class Worker:
|
|
|
204
241
|
logger.error(f"failed to start action listener: {e}")
|
|
205
242
|
sys.exit(1)
|
|
206
243
|
|
|
207
|
-
async def _check_listener_health(self):
|
|
244
|
+
async def _check_listener_health(self) -> None:
|
|
208
245
|
logger.debug("starting action listener health check...")
|
|
209
246
|
try:
|
|
210
247
|
while not self.killing:
|
|
@@ -224,21 +261,21 @@ class Worker:
|
|
|
224
261
|
logger.error(f"error checking listener health: {e}")
|
|
225
262
|
|
|
226
263
|
## Cleanup methods
|
|
227
|
-
def _setup_signal_handlers(self):
|
|
264
|
+
def _setup_signal_handlers(self) -> None:
|
|
228
265
|
signal.signal(signal.SIGTERM, self._handle_exit_signal)
|
|
229
266
|
signal.signal(signal.SIGINT, self._handle_exit_signal)
|
|
230
267
|
signal.signal(signal.SIGQUIT, self._handle_force_quit_signal)
|
|
231
268
|
|
|
232
|
-
def _handle_exit_signal(self, signum, frame):
|
|
269
|
+
def _handle_exit_signal(self, signum: int, frame: FrameType | None) -> None:
|
|
233
270
|
sig_name = "SIGTERM" if signum == signal.SIGTERM else "SIGINT"
|
|
234
271
|
logger.info(f"received signal {sig_name}...")
|
|
235
272
|
self.loop.create_task(self.exit_gracefully())
|
|
236
273
|
|
|
237
|
-
def _handle_force_quit_signal(self, signum, frame):
|
|
274
|
+
def _handle_force_quit_signal(self, signum: int, frame: FrameType | None) -> None:
|
|
238
275
|
logger.info("received SIGQUIT...")
|
|
239
276
|
self.exit_forcefully()
|
|
240
277
|
|
|
241
|
-
async def close(self):
|
|
278
|
+
async def close(self) -> None:
|
|
242
279
|
logger.info(f"closing worker '{self.name}'...")
|
|
243
280
|
self.killing = True
|
|
244
281
|
# self.action_queue.close()
|
|
@@ -249,7 +286,7 @@ class Worker:
|
|
|
249
286
|
|
|
250
287
|
await self.action_listener_health_check
|
|
251
288
|
|
|
252
|
-
async def exit_gracefully(self):
|
|
289
|
+
async def exit_gracefully(self) -> None:
|
|
253
290
|
logger.debug(f"gracefully stopping worker: {self.name}")
|
|
254
291
|
|
|
255
292
|
if self.killing:
|
|
@@ -270,7 +307,7 @@ class Worker:
|
|
|
270
307
|
|
|
271
308
|
logger.info("👋")
|
|
272
309
|
|
|
273
|
-
def exit_forcefully(self):
|
|
310
|
+
def exit_forcefully(self) -> None:
|
|
274
311
|
self.killing = True
|
|
275
312
|
|
|
276
313
|
logger.debug(f"forcefully stopping worker: {self.name}")
|
|
@@ -286,7 +323,7 @@ class Worker:
|
|
|
286
323
|
) # Exit immediately TODO - should we exit with 1 here, there may be other workers to cleanup
|
|
287
324
|
|
|
288
325
|
|
|
289
|
-
def register_on_worker(callable: HatchetCallable, worker: Worker):
|
|
326
|
+
def register_on_worker(callable: HatchetCallable[T], worker: Worker) -> None:
|
|
290
327
|
worker.register_function(callable.get_action_name(), callable)
|
|
291
328
|
|
|
292
329
|
if callable.function_on_failure is not None:
|