hatchet-sdk: hatchet_sdk-0.41.0-py3-none-any.whl → hatchet_sdk-0.42.1-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of hatchet-sdk might be problematic. (The original page linked to more details here; the link target was not preserved in this text export.)

@@ -8,9 +8,10 @@ from concurrent.futures import ThreadPoolExecutor
8
8
  from enum import Enum
9
9
  from multiprocessing import Queue
10
10
  from threading import Thread, current_thread
11
- from typing import Any, Callable, Dict
11
+ from typing import Any, Callable, Dict, Literal, Type, TypeVar, cast, overload
12
12
 
13
13
  from opentelemetry.trace import StatusCode
14
+ from pydantic import BaseModel
14
15
 
15
16
  from hatchet_sdk.client import new_client_raw
16
17
  from hatchet_sdk.clients.admin import new_admin
@@ -18,9 +19,9 @@ from hatchet_sdk.clients.dispatcher.action_listener import Action
18
19
  from hatchet_sdk.clients.dispatcher.dispatcher import new_dispatcher
19
20
  from hatchet_sdk.clients.run_event_listener import new_listener
20
21
  from hatchet_sdk.clients.workflow_listener import PooledWorkflowRunListener
21
- from hatchet_sdk.context import Context
22
+ from hatchet_sdk.context import Context # type: ignore[attr-defined]
22
23
  from hatchet_sdk.context.worker_context import WorkerContext
23
- from hatchet_sdk.contracts.dispatcher_pb2 import (
24
+ from hatchet_sdk.contracts.dispatcher_pb2 import ( # type: ignore[attr-defined]
24
25
  GROUP_KEY_EVENT_TYPE_COMPLETED,
25
26
  GROUP_KEY_EVENT_TYPE_FAILED,
26
27
  GROUP_KEY_EVENT_TYPE_STARTED,
@@ -32,6 +33,7 @@ from hatchet_sdk.contracts.dispatcher_pb2 import (
32
33
  from hatchet_sdk.loader import ClientConfig
33
34
  from hatchet_sdk.logger import logger
34
35
  from hatchet_sdk.utils.tracing import create_tracer, parse_carrier_from_metadata
36
+ from hatchet_sdk.utils.types import WorkflowValidator
35
37
  from hatchet_sdk.v2.callable import DurableContext
36
38
  from hatchet_sdk.worker.action_listener_process import ActionEvent
37
39
  from hatchet_sdk.worker.runner.utils.capture_logs import copy_context_vars, sr, wr
@@ -48,11 +50,12 @@ class Runner:
48
50
  def __init__(
49
51
  self,
50
52
  name: str,
51
- event_queue: Queue,
53
+ event_queue: "Queue[Any]",
52
54
  max_runs: int | None = None,
53
55
  handle_kill: bool = True,
54
56
  action_registry: dict[str, Callable[..., Any]] = {},
55
- config: ClientConfig = {},
57
+ validator_registry: dict[str, WorkflowValidator] = {},
58
+ config: ClientConfig = ClientConfig(),
56
59
  labels: dict[str, str | int] = {},
57
60
  ):
58
61
  # We store the config so we can dynamically create clients for the dispatcher client.
@@ -60,9 +63,10 @@ class Runner:
60
63
  self.client = new_client_raw(config)
61
64
  self.name = self.client.config.namespace + name
62
65
  self.max_runs = max_runs
63
- self.tasks: Dict[str, asyncio.Task] = {} # Store run ids and futures
64
- self.contexts: Dict[str, Context] = {} # Store run ids and contexts
66
+ self.tasks: dict[str, asyncio.Task[Any]] = {} # Store run ids and futures
67
+ self.contexts: dict[str, Context] = {} # Store run ids and contexts
65
68
  self.action_registry: dict[str, Callable[..., Any]] = action_registry
69
+ self.validator_registry = validator_registry
66
70
 
67
71
  self.event_queue = event_queue
68
72
 
@@ -89,7 +93,7 @@ class Runner:
89
93
  def create_workflow_run_url(self, action: Action) -> str:
90
94
  return f"{self.config.server_url}/workflow-runs/{action.workflow_run_id}?tenant={action.tenant_id}"
91
95
 
92
- def run(self, action: Action):
96
+ def run(self, action: Action) -> None:
93
97
  ctx = parse_carrier_from_metadata(action.additional_metadata)
94
98
 
95
99
  with self.otel_tracer.start_as_current_span(
@@ -122,8 +126,8 @@ class Runner:
122
126
  span.add_event(log)
123
127
  logger.error(log)
124
128
 
125
- def step_run_callback(self, action: Action):
126
- def inner_callback(task: asyncio.Task):
129
+ def step_run_callback(self, action: Action) -> Callable[[asyncio.Task[Any]], None]:
130
+ def inner_callback(task: asyncio.Task[Any]) -> None:
127
131
  self.cleanup_run_id(action.step_run_id)
128
132
 
129
133
  errored = False
@@ -164,8 +168,10 @@ class Runner:
164
168
 
165
169
  return inner_callback
166
170
 
167
- def group_key_run_callback(self, action: Action):
168
- def inner_callback(task: asyncio.Task):
171
+ def group_key_run_callback(
172
+ self, action: Action
173
+ ) -> Callable[[asyncio.Task[Any]], None]:
174
+ def inner_callback(task: asyncio.Task[Any]) -> None:
169
175
  self.cleanup_run_id(action.get_group_key_run_id)
170
176
 
171
177
  errored = False
@@ -204,7 +210,10 @@ class Runner:
204
210
 
205
211
  return inner_callback
206
212
 
207
- def thread_action_func(self, context, action_func, action: Action):
213
+ ## TODO: Stricter type hinting here
214
+ def thread_action_func(
215
+ self, context: Context, action_func: Callable[..., Any], action: Action
216
+ ) -> Any:
208
217
  if action.step_run_id is not None and action.step_run_id != "":
209
218
  self.threads[action.step_run_id] = current_thread()
210
219
  elif (
@@ -215,10 +224,15 @@ class Runner:
215
224
 
216
225
  return action_func(context)
217
226
 
227
+ ## TODO: Stricter type hinting here
218
228
  # We wrap all actions in an async func
219
229
  async def async_wrapped_action_func(
220
- self, context: Context, action_func, action: Action, run_id: str
221
- ):
230
+ self,
231
+ context: Context,
232
+ action_func: Callable[..., Any],
233
+ action: Action,
234
+ run_id: str,
235
+ ) -> Any:
222
236
  wr.set(context.workflow_run_id())
223
237
  sr.set(context.step_run_id)
224
238
 
@@ -240,9 +254,7 @@ class Runner:
240
254
  )
241
255
 
242
256
  loop = asyncio.get_event_loop()
243
- res = await loop.run_in_executor(self.thread_pool, pfunc)
244
-
245
- return res
257
+ return await loop.run_in_executor(self.thread_pool, pfunc)
246
258
  except Exception as e:
247
259
  logger.error(
248
260
  errorWithTraceback(
@@ -254,7 +266,7 @@ class Runner:
254
266
  finally:
255
267
  self.cleanup_run_id(run_id)
256
268
 
257
- def cleanup_run_id(self, run_id: str):
269
+ def cleanup_run_id(self, run_id: str | None) -> None:
258
270
  if run_id in self.tasks:
259
271
  del self.tasks[run_id]
260
272
 
@@ -267,7 +279,7 @@ class Runner:
267
279
  def create_context(
268
280
  self, action: Action, action_func: Callable[..., Any] | None
269
281
  ) -> Context | DurableContext:
270
- if hasattr(action_func, "durable") and action_func.durable:
282
+ if hasattr(action_func, "durable") and getattr(action_func, "durable"):
271
283
  return DurableContext(
272
284
  action,
273
285
  self.dispatcher_client,
@@ -278,6 +290,7 @@ class Runner:
278
290
  self.workflow_run_event_listener,
279
291
  self.worker_context,
280
292
  self.client.config.namespace,
293
+ validator_registry=self.validator_registry,
281
294
  )
282
295
 
283
296
  return Context(
@@ -290,9 +303,10 @@ class Runner:
290
303
  self.workflow_run_event_listener,
291
304
  self.worker_context,
292
305
  self.client.config.namespace,
306
+ validator_registry=self.validator_registry,
293
307
  )
294
308
 
295
- async def handle_start_step_run(self, action: Action):
309
+ async def handle_start_step_run(self, action: Action) -> None:
296
310
  with self.otel_tracer.start_as_current_span(
297
311
  f"hatchet.worker.handle_start_step_run.{action.step_id}",
298
312
  ) as span:
@@ -336,7 +350,7 @@ class Runner:
336
350
 
337
351
  span.add_event("Finished step run")
338
352
 
339
- async def handle_start_group_key_run(self, action: Action):
353
+ async def handle_start_group_key_run(self, action: Action) -> None:
340
354
  with self.otel_tracer.start_as_current_span(
341
355
  f"hatchet.worker.handle_start_step_run.{action.step_id}"
342
356
  ) as span:
@@ -353,6 +367,7 @@ class Runner:
353
367
  self.worker_context,
354
368
  self.client.config.namespace,
355
369
  )
370
+
356
371
  self.contexts[action.get_group_key_run_id] = context
357
372
 
358
373
  # Find the corresponding action function from the registry
@@ -387,18 +402,18 @@ class Runner:
387
402
 
388
403
  span.add_event("Finished group key run")
389
404
 
390
- def force_kill_thread(self, thread):
405
+ def force_kill_thread(self, thread: Thread) -> None:
391
406
  """Terminate a python threading.Thread."""
392
407
  try:
393
408
  if not thread.is_alive():
394
409
  return
395
410
 
396
- logger.info(f"Forcefully terminating thread {thread.ident}")
411
+ ident = cast(int, thread.ident)
412
+
413
+ logger.info(f"Forcefully terminating thread {ident}")
397
414
 
398
415
  exc = ctypes.py_object(SystemExit)
399
- res = ctypes.pythonapi.PyThreadState_SetAsyncExc(
400
- ctypes.c_long(thread.ident), exc
401
- )
416
+ res = ctypes.pythonapi.PyThreadState_SetAsyncExc(ctypes.c_long(ident), exc)
402
417
  if res == 0:
403
418
  raise ValueError("Invalid thread ID")
404
419
  elif res != 1:
@@ -408,7 +423,7 @@ class Runner:
408
423
  ctypes.pythonapi.PyThreadState_SetAsyncExc(thread.ident, 0)
409
424
  raise SystemError("PyThreadState_SetAsyncExc failed")
410
425
 
411
- logger.info(f"Successfully terminated thread {thread.ident}")
426
+ logger.info(f"Successfully terminated thread {ident}")
412
427
 
413
428
  # Immediately add a new thread to the thread pool, because we've actually killed a worker
414
429
  # in the ThreadPoolExecutor
@@ -416,7 +431,7 @@ class Runner:
416
431
  except Exception as e:
417
432
  logger.exception(f"Failed to terminate thread: {e}")
418
433
 
419
- async def handle_cancel_action(self, run_id: str):
434
+ async def handle_cancel_action(self, run_id: str) -> None:
420
435
  with self.otel_tracer.start_as_current_span(
421
436
  "hatchet.worker.handle_cancel_action"
422
437
  ) as span:
@@ -427,7 +442,9 @@ class Runner:
427
442
  # call cancel to signal the context to stop
428
443
  if run_id in self.contexts:
429
444
  context = self.contexts.get(run_id)
430
- context.cancel()
445
+
446
+ if context:
447
+ context.cancel()
431
448
 
432
449
  await asyncio.sleep(1)
433
450
 
@@ -449,16 +466,20 @@ class Runner:
449
466
  span.add_event(f"Finished cancelling run id: {run_id}")
450
467
 
451
468
  def serialize_output(self, output: Any) -> str:
452
- output_bytes = ""
469
+
470
+ if isinstance(output, BaseModel):
471
+ return output.model_dump_json()
472
+
453
473
  if output is not None:
454
474
  try:
455
- output_bytes = json.dumps(output)
475
+ return json.dumps(output)
456
476
  except Exception as e:
457
477
  logger.error(f"Could not serialize output: {e}")
458
- output_bytes = str(output)
459
- return output_bytes
478
+ return str(output)
479
+
480
+ return ""
460
481
 
461
- async def wait_for_tasks(self):
482
+ async def wait_for_tasks(self) -> None:
462
483
  running = len(self.tasks.keys())
463
484
  while running > 0:
464
485
  logger.info(f"waiting for {running} tasks to finish...")
@@ -466,6 +487,6 @@ class Runner:
466
487
  running = len(self.tasks.keys())
467
488
 
468
489
 
469
- def errorWithTraceback(message: str, e: Exception):
490
+ def errorWithTraceback(message: str, e: Exception) -> str:
470
491
  trace = "".join(traceback.format_exception(type(e), e, e.__traceback__))
471
492
  return f"{message}\n{trace}"
@@ -1,22 +1,32 @@
1
1
  import asyncio
2
2
  import multiprocessing
3
+ import multiprocessing.context
3
4
  import os
4
5
  import signal
5
6
  import sys
7
+ from concurrent.futures import Future
6
8
  from dataclasses import dataclass, field
7
9
  from enum import Enum
8
- from multiprocessing import Process, Queue
9
- from typing import Any, Callable, Dict, Optional
10
+ from multiprocessing import Queue
11
+ from multiprocessing.process import BaseProcess
12
+ from types import FrameType
13
+ from typing import Any, Callable, TypeVar, get_type_hints
10
14
 
15
+ from hatchet_sdk import Context
11
16
  from hatchet_sdk.client import Client, new_client_raw
12
- from hatchet_sdk.context import Context
13
- from hatchet_sdk.contracts.workflows_pb2 import CreateWorkflowVersionOpts
17
+ from hatchet_sdk.contracts.workflows_pb2 import ( # type: ignore[attr-defined]
18
+ CreateWorkflowVersionOpts,
19
+ )
14
20
  from hatchet_sdk.loader import ClientConfig
15
21
  from hatchet_sdk.logger import logger
22
+ from hatchet_sdk.utils.types import WorkflowValidator
16
23
  from hatchet_sdk.v2.callable import HatchetCallable
24
+ from hatchet_sdk.v2.concurrency import ConcurrencyFunction
17
25
  from hatchet_sdk.worker.action_listener_process import worker_action_listener_process
18
26
  from hatchet_sdk.worker.runner.run_loop_manager import WorkerActionRunLoopManager
19
- from hatchet_sdk.workflow import WorkflowMeta
27
+ from hatchet_sdk.workflow import WorkflowInterface
28
+
29
+ T = TypeVar("T")
20
30
 
21
31
 
22
32
  class WorkerStatus(Enum):
@@ -28,46 +38,60 @@ class WorkerStatus(Enum):
28
38
 
29
39
  @dataclass
30
40
  class WorkerStartOptions:
31
- loop: asyncio.AbstractEventLoop = field(default=None)
41
+ loop: asyncio.AbstractEventLoop | None = field(default=None)
32
42
 
33
43
 
34
- @dataclass
35
44
  class Worker:
36
- name: str
37
- config: ClientConfig = field(default_factory=dict)
38
- max_runs: Optional[int] = None
39
- debug: bool = False
40
- labels: dict[str, str | int] = field(default_factory=dict)
41
- handle_kill: bool = True
42
-
43
- client: Client = field(init=False)
44
- tasks: Dict[str, asyncio.Task] = field(default_factory=dict)
45
- contexts: Dict[str, Context] = field(default_factory=dict)
46
- action_registry: Dict[str, Callable[..., Any]] = field(default_factory=dict)
47
- killing: bool = field(init=False, default=False)
48
- _status: WorkerStatus = field(init=False, default=WorkerStatus.INITIALIZED)
49
-
50
- action_listener_process: Process = field(init=False, default=None)
51
- action_listener_health_check: asyncio.Task = field(init=False, default=None)
52
- action_runner: WorkerActionRunLoopManager = field(init=False, default=None)
53
- ctx = multiprocessing.get_context("spawn")
54
-
55
- action_queue: Queue = field(init=False, default_factory=ctx.Queue)
56
- event_queue: Queue = field(init=False, default_factory=ctx.Queue)
57
-
58
- loop: asyncio.AbstractEventLoop = field(init=False, default=None)
59
- owned_loop: bool = True
60
-
61
- def __post_init__(self):
45
+ def __init__(
46
+ self,
47
+ name: str,
48
+ config: ClientConfig = ClientConfig(),
49
+ max_runs: int | None = None,
50
+ labels: dict[str, str | int] = {},
51
+ debug: bool = False,
52
+ owned_loop: bool = True,
53
+ handle_kill: bool = True,
54
+ ) -> None:
55
+ self.name = name
56
+ self.config = config
57
+ self.max_runs = max_runs
58
+ self.debug = debug
59
+ self.labels = labels
60
+ self.handle_kill = handle_kill
61
+ self.owned_loop = owned_loop
62
+
63
+ self.client: Client
64
+
65
+ self.action_registry: dict[str, Callable[[Context], T]] = {}
66
+ self.validator_registry: dict[str, WorkflowValidator] = {}
67
+
68
+ self.killing: bool = False
69
+ self._status: WorkerStatus
70
+
71
+ self.action_listener_process: BaseProcess
72
+ self.action_listener_health_check: asyncio.Task[Any]
73
+ self.action_runner: WorkerActionRunLoopManager
74
+
75
+ self.ctx = multiprocessing.get_context("spawn")
76
+
77
+ self.action_queue: "Queue[Any]" = self.ctx.Queue()
78
+ self.event_queue: "Queue[Any]" = self.ctx.Queue()
79
+
80
+ self.loop: asyncio.AbstractEventLoop
81
+
62
82
  self.client = new_client_raw(self.config, self.debug)
63
83
  self.name = self.client.config.namespace + self.name
64
- if self.owned_loop:
65
- self._setup_signal_handlers()
66
84
 
67
- def register_function(self, action: str, func: HatchetCallable):
85
+ self._setup_signal_handlers()
86
+
87
+ def register_function(
88
+ self, action: str, func: HatchetCallable[Any] | ConcurrencyFunction
89
+ ) -> None:
68
90
  self.action_registry[action] = func
69
91
 
70
- def register_workflow_from_opts(self, name: str, opts: CreateWorkflowVersionOpts):
92
+ def register_workflow_from_opts(
93
+ self, name: str, opts: CreateWorkflowVersionOpts
94
+ ) -> None:
71
95
  try:
72
96
  self.client.admin.put_workflow(opts.name, opts)
73
97
  except Exception as e:
@@ -75,7 +99,7 @@ class Worker:
75
99
  logger.error(e)
76
100
  sys.exit(1)
77
101
 
78
- def register_workflow(self, workflow: WorkflowMeta):
102
+ def register_workflow(self, workflow: WorkflowInterface) -> None:
79
103
  namespace = self.client.config.namespace
80
104
 
81
105
  try:
@@ -87,24 +111,30 @@ class Worker:
87
111
  logger.error(e)
88
112
  sys.exit(1)
89
113
 
90
- def create_action_function(action_func):
91
- def action_function(context):
114
+ def create_action_function(
115
+ action_func: Callable[..., T]
116
+ ) -> Callable[[Context], T]:
117
+ def action_function(context: Context) -> T:
92
118
  return action_func(workflow, context)
93
119
 
94
120
  if asyncio.iscoroutinefunction(action_func):
95
- action_function.is_coroutine = True
121
+ setattr(action_function, "is_coroutine", True)
96
122
  else:
97
- action_function.is_coroutine = False
123
+ setattr(action_function, "is_coroutine", False)
98
124
 
99
125
  return action_function
100
126
 
101
127
  for action_name, action_func in workflow.get_actions(namespace):
102
128
  self.action_registry[action_name] = create_action_function(action_func)
129
+ return_type = get_type_hints(action_func).get("return")
130
+ self.validator_registry[action_name] = WorkflowValidator(
131
+ workflow_input=workflow.input_validator, step_output=return_type
132
+ )
103
133
 
104
134
  def status(self) -> WorkerStatus:
105
135
  return self._status
106
136
 
107
- def setup_loop(self, loop: asyncio.AbstractEventLoop = None):
137
+ def setup_loop(self, loop: asyncio.AbstractEventLoop | None = None) -> bool:
108
138
  try:
109
139
  loop = loop or asyncio.get_running_loop()
110
140
  self.loop = loop
@@ -118,17 +148,22 @@ class Worker:
118
148
  created_loop = True
119
149
  return created_loop
120
150
 
121
- def start(self, options: WorkerStartOptions = WorkerStartOptions()):
151
+ def start(
152
+ self, options: WorkerStartOptions = WorkerStartOptions()
153
+ ) -> Future[asyncio.Task[Any] | None]:
122
154
  self.owned_loop = self.setup_loop(options.loop)
155
+
123
156
  f = asyncio.run_coroutine_threadsafe(
124
157
  self.async_start(options, _from_start=True), self.loop
125
158
  )
159
+
126
160
  # start the loop and wait until its closed
127
161
  if self.owned_loop:
128
162
  self.loop.run_forever()
129
163
 
130
164
  if self.handle_kill:
131
165
  sys.exit(0)
166
+
132
167
  return f
133
168
 
134
169
  ## Start methods
@@ -136,7 +171,7 @@ class Worker:
136
171
  self,
137
172
  options: WorkerStartOptions = WorkerStartOptions(),
138
173
  _from_start: bool = False,
139
- ):
174
+ ) -> Any | None:
140
175
  main_pid = os.getpid()
141
176
  logger.info("------------------------------------------")
142
177
  logger.info("STARTING HATCHET...")
@@ -148,25 +183,28 @@ class Worker:
148
183
  logger.error(
149
184
  "no actions registered, register workflows or actions before starting worker"
150
185
  )
151
- return
186
+ return None
152
187
 
153
188
  # non blocking setup
154
189
  if not _from_start:
155
190
  self.setup_loop(options.loop)
156
191
 
157
192
  self.action_listener_process = self._start_listener()
193
+
158
194
  self.action_runner = self._run_action_runner()
195
+
159
196
  self.action_listener_health_check = self.loop.create_task(
160
197
  self._check_listener_health()
161
198
  )
162
199
 
163
200
  return await self.action_listener_health_check
164
201
 
165
- def _run_action_runner(self):
202
+ def _run_action_runner(self) -> WorkerActionRunLoopManager:
166
203
  # Retrieve the shared queue
167
- runner = WorkerActionRunLoopManager(
204
+ return WorkerActionRunLoopManager(
168
205
  self.name,
169
206
  self.action_registry,
207
+ self.validator_registry,
170
208
  self.max_runs,
171
209
  self.config,
172
210
  self.action_queue,
@@ -177,10 +215,9 @@ class Worker:
177
215
  self.labels,
178
216
  )
179
217
 
180
- return runner
181
-
182
- def _start_listener(self):
218
+ def _start_listener(self) -> multiprocessing.context.SpawnProcess:
183
219
  action_list = [str(key) for key in self.action_registry.keys()]
220
+
184
221
  try:
185
222
  process = self.ctx.Process(
186
223
  target=worker_action_listener_process,
@@ -204,7 +241,7 @@ class Worker:
204
241
  logger.error(f"failed to start action listener: {e}")
205
242
  sys.exit(1)
206
243
 
207
- async def _check_listener_health(self):
244
+ async def _check_listener_health(self) -> None:
208
245
  logger.debug("starting action listener health check...")
209
246
  try:
210
247
  while not self.killing:
@@ -224,21 +261,21 @@ class Worker:
224
261
  logger.error(f"error checking listener health: {e}")
225
262
 
226
263
  ## Cleanup methods
227
- def _setup_signal_handlers(self):
264
+ def _setup_signal_handlers(self) -> None:
228
265
  signal.signal(signal.SIGTERM, self._handle_exit_signal)
229
266
  signal.signal(signal.SIGINT, self._handle_exit_signal)
230
267
  signal.signal(signal.SIGQUIT, self._handle_force_quit_signal)
231
268
 
232
- def _handle_exit_signal(self, signum, frame):
269
+ def _handle_exit_signal(self, signum: int, frame: FrameType | None) -> None:
233
270
  sig_name = "SIGTERM" if signum == signal.SIGTERM else "SIGINT"
234
271
  logger.info(f"received signal {sig_name}...")
235
272
  self.loop.create_task(self.exit_gracefully())
236
273
 
237
- def _handle_force_quit_signal(self, signum, frame):
274
+ def _handle_force_quit_signal(self, signum: int, frame: FrameType | None) -> None:
238
275
  logger.info("received SIGQUIT...")
239
276
  self.exit_forcefully()
240
277
 
241
- async def close(self):
278
+ async def close(self) -> None:
242
279
  logger.info(f"closing worker '{self.name}'...")
243
280
  self.killing = True
244
281
  # self.action_queue.close()
@@ -249,7 +286,7 @@ class Worker:
249
286
 
250
287
  await self.action_listener_health_check
251
288
 
252
- async def exit_gracefully(self):
289
+ async def exit_gracefully(self) -> None:
253
290
  logger.debug(f"gracefully stopping worker: {self.name}")
254
291
 
255
292
  if self.killing:
@@ -270,7 +307,7 @@ class Worker:
270
307
 
271
308
  logger.info("👋")
272
309
 
273
- def exit_forcefully(self):
310
+ def exit_forcefully(self) -> None:
274
311
  self.killing = True
275
312
 
276
313
  logger.debug(f"forcefully stopping worker: {self.name}")
@@ -286,7 +323,7 @@ class Worker:
286
323
  ) # Exit immediately TODO - should we exit with 1 here, there may be other workers to cleanup
287
324
 
288
325
 
289
- def register_on_worker(callable: HatchetCallable, worker: Worker):
326
+ def register_on_worker(callable: HatchetCallable[T], worker: Worker) -> None:
290
327
  worker.register_function(callable.get_action_name(), callable)
291
328
 
292
329
  if callable.function_on_failure is not None: