hatchet-sdk 1.0.0__py3-none-any.whl → 1.0.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of hatchet-sdk might be problematic. Click here for more details.

Files changed (73) hide show
  1. hatchet_sdk/__init__.py +32 -16
  2. hatchet_sdk/client.py +25 -63
  3. hatchet_sdk/clients/admin.py +203 -142
  4. hatchet_sdk/clients/dispatcher/action_listener.py +42 -42
  5. hatchet_sdk/clients/dispatcher/dispatcher.py +18 -16
  6. hatchet_sdk/clients/durable_event_listener.py +327 -0
  7. hatchet_sdk/clients/rest/__init__.py +12 -1
  8. hatchet_sdk/clients/rest/api/log_api.py +258 -0
  9. hatchet_sdk/clients/rest/api/task_api.py +32 -6
  10. hatchet_sdk/clients/rest/api/workflow_runs_api.py +626 -0
  11. hatchet_sdk/clients/rest/models/__init__.py +12 -1
  12. hatchet_sdk/clients/rest/models/v1_log_line.py +94 -0
  13. hatchet_sdk/clients/rest/models/v1_log_line_level.py +39 -0
  14. hatchet_sdk/clients/rest/models/v1_log_line_list.py +110 -0
  15. hatchet_sdk/clients/rest/models/v1_task_summary.py +80 -64
  16. hatchet_sdk/clients/rest/models/v1_trigger_workflow_run_request.py +95 -0
  17. hatchet_sdk/clients/rest/models/v1_workflow_run_display_name.py +98 -0
  18. hatchet_sdk/clients/rest/models/v1_workflow_run_display_name_list.py +114 -0
  19. hatchet_sdk/clients/rest/models/workflow_run_shape_item_for_workflow_run_details.py +9 -4
  20. hatchet_sdk/clients/rest/models/workflow_runs_metrics.py +5 -1
  21. hatchet_sdk/clients/run_event_listener.py +0 -1
  22. hatchet_sdk/clients/v1/api_client.py +81 -0
  23. hatchet_sdk/context/context.py +86 -159
  24. hatchet_sdk/contracts/dispatcher_pb2_grpc.py +1 -1
  25. hatchet_sdk/contracts/events_pb2.py +2 -2
  26. hatchet_sdk/contracts/events_pb2_grpc.py +1 -1
  27. hatchet_sdk/contracts/v1/dispatcher_pb2.py +36 -0
  28. hatchet_sdk/contracts/v1/dispatcher_pb2.pyi +38 -0
  29. hatchet_sdk/contracts/v1/dispatcher_pb2_grpc.py +145 -0
  30. hatchet_sdk/contracts/v1/shared/condition_pb2.py +39 -0
  31. hatchet_sdk/contracts/v1/shared/condition_pb2.pyi +72 -0
  32. hatchet_sdk/contracts/v1/shared/condition_pb2_grpc.py +29 -0
  33. hatchet_sdk/contracts/v1/workflows_pb2.py +67 -0
  34. hatchet_sdk/contracts/v1/workflows_pb2.pyi +228 -0
  35. hatchet_sdk/contracts/v1/workflows_pb2_grpc.py +234 -0
  36. hatchet_sdk/contracts/workflows_pb2_grpc.py +1 -1
  37. hatchet_sdk/features/cron.py +91 -121
  38. hatchet_sdk/features/logs.py +16 -0
  39. hatchet_sdk/features/metrics.py +75 -0
  40. hatchet_sdk/features/rate_limits.py +45 -0
  41. hatchet_sdk/features/runs.py +221 -0
  42. hatchet_sdk/features/scheduled.py +114 -131
  43. hatchet_sdk/features/workers.py +41 -0
  44. hatchet_sdk/features/workflows.py +55 -0
  45. hatchet_sdk/hatchet.py +463 -165
  46. hatchet_sdk/opentelemetry/instrumentor.py +8 -13
  47. hatchet_sdk/rate_limit.py +33 -39
  48. hatchet_sdk/runnables/contextvars.py +12 -0
  49. hatchet_sdk/runnables/standalone.py +192 -0
  50. hatchet_sdk/runnables/task.py +144 -0
  51. hatchet_sdk/runnables/types.py +138 -0
  52. hatchet_sdk/runnables/workflow.py +771 -0
  53. hatchet_sdk/utils/aio_utils.py +0 -79
  54. hatchet_sdk/utils/proto_enums.py +0 -7
  55. hatchet_sdk/utils/timedelta_to_expression.py +23 -0
  56. hatchet_sdk/utils/typing.py +2 -2
  57. hatchet_sdk/v0/clients/rest_client.py +9 -0
  58. hatchet_sdk/v0/worker/action_listener_process.py +18 -2
  59. hatchet_sdk/waits.py +120 -0
  60. hatchet_sdk/worker/action_listener_process.py +64 -30
  61. hatchet_sdk/worker/runner/run_loop_manager.py +35 -26
  62. hatchet_sdk/worker/runner/runner.py +72 -55
  63. hatchet_sdk/worker/runner/utils/capture_logs.py +3 -11
  64. hatchet_sdk/worker/worker.py +155 -118
  65. hatchet_sdk/workflow_run.py +4 -5
  66. {hatchet_sdk-1.0.0.dist-info → hatchet_sdk-1.0.1.dist-info}/METADATA +1 -2
  67. {hatchet_sdk-1.0.0.dist-info → hatchet_sdk-1.0.1.dist-info}/RECORD +69 -43
  68. {hatchet_sdk-1.0.0.dist-info → hatchet_sdk-1.0.1.dist-info}/entry_points.txt +2 -0
  69. hatchet_sdk/clients/rest_client.py +0 -636
  70. hatchet_sdk/semver.py +0 -30
  71. hatchet_sdk/worker/runner/utils/error_with_traceback.py +0 -6
  72. hatchet_sdk/workflow.py +0 -527
  73. {hatchet_sdk-1.0.0.dist-info → hatchet_sdk-1.0.1.dist-info}/WHEEL +0 -0
@@ -1,83 +1,4 @@
1
1
  import asyncio
2
- import inspect
3
- from concurrent.futures import Executor
4
- from functools import partial, wraps
5
- from typing import Any
6
-
7
-
8
- ## TODO: Stricter typing here
9
- def sync_to_async(func: Any) -> Any:
10
- """
11
- A decorator to run a synchronous function or coroutine in an asynchronous context with added
12
- asyncio loop safety.
13
-
14
- This decorator allows you to safely call synchronous functions or coroutines from an
15
- asynchronous function by running them in an executor.
16
-
17
- Args:
18
- func (callable): The synchronous function or coroutine to be run asynchronously.
19
-
20
- Returns:
21
- callable: An asynchronous wrapper function that runs the given function in an executor.
22
-
23
- Example:
24
- @sync_to_async
25
- def sync_function(x, y):
26
- return x + y
27
-
28
- @sync_to_async
29
- async def async_function(x, y):
30
- return x + y
31
-
32
-
33
- def undecorated_function(x, y):
34
- return x + y
35
-
36
- async def main():
37
- result1 = await sync_function(1, 2)
38
- result2 = await async_function(3, 4)
39
- result3 = await sync_to_async(undecorated_function)(5, 6)
40
- print(result1, result2, result3)
41
-
42
- asyncio.run(main())
43
- """
44
-
45
- ## TODO: Stricter typing here
46
- @wraps(func)
47
- async def run(
48
- *args: Any,
49
- loop: asyncio.AbstractEventLoop | None = None,
50
- executor: Executor | None = None,
51
- **kwargs: Any
52
- ) -> Any:
53
- """
54
- The asynchronous wrapper function that runs the given function in an executor.
55
-
56
- Args:
57
- *args: Positional arguments to pass to the function.
58
- loop (asyncio.AbstractEventLoop, optional): The event loop to use. If None, the current running loop is used.
59
- executor (concurrent.futures.Executor, optional): The executor to use. If None, the default executor is used.
60
- **kwargs: Keyword arguments to pass to the function.
61
-
62
- Returns:
63
- The result of the function call.
64
- """
65
- if loop is None:
66
- loop = asyncio.get_running_loop()
67
-
68
- if inspect.iscoroutinefunction(func):
69
- # Wrap the coroutine to run it in an executor
70
- async def wrapper() -> Any:
71
- return await func(*args, **kwargs)
72
-
73
- pfunc = partial(asyncio.run, wrapper())
74
- return await loop.run_in_executor(executor, pfunc)
75
- else:
76
- # Run the synchronous function in an executor
77
- pfunc = partial(func, *args, **kwargs)
78
- return await loop.run_in_executor(executor, pfunc)
79
-
80
- return run
81
2
 
82
3
 
83
4
  def get_active_event_loop() -> asyncio.AbstractEventLoop | None:
@@ -45,10 +45,3 @@ def convert_proto_enum_to_python(
45
45
  return None
46
46
 
47
47
  return python_enum_class[proto_enum.Name(value)]
48
-
49
-
50
- def maybe_int_to_str(value: int | None) -> str | None:
51
- if value is None:
52
- return None
53
-
54
- return str(value)
@@ -0,0 +1,23 @@
1
+ from datetime import timedelta
2
+
3
+ DAY = 86400
4
+ HOUR = 3600
5
+ MINUTE = 60
6
+
7
+ Duration = timedelta | str
8
+
9
+
10
+ def timedelta_to_expr(td: Duration) -> str:
11
+ if isinstance(td, str):
12
+ return td
13
+
14
+ seconds = td.seconds
15
+
16
+ if seconds % DAY == 0:
17
+ return f"{seconds // DAY}d"
18
+ elif seconds % HOUR == 0:
19
+ return f"{seconds // HOUR}h"
20
+ elif seconds % MINUTE == 0:
21
+ return f"{seconds // MINUTE}m"
22
+ else:
23
+ return f"{seconds}s"
@@ -1,11 +1,11 @@
1
- from typing import Any, Mapping, Type, TypeVar
1
+ from typing import Any, Mapping, Type, TypeGuard, TypeVar
2
2
 
3
3
  from pydantic import BaseModel
4
4
 
5
5
  T = TypeVar("T", bound=BaseModel)
6
6
 
7
7
 
8
- def is_basemodel_subclass(model: Any) -> bool:
8
+ def is_basemodel_subclass(model: Any) -> TypeGuard[Type[BaseModel]]:
9
9
  try:
10
10
  return issubclass(model, BaseModel)
11
11
  except TypeError:
@@ -9,6 +9,7 @@ from pydantic import StrictInt
9
9
  from hatchet_sdk.v0.clients.rest.api.event_api import EventApi
10
10
  from hatchet_sdk.v0.clients.rest.api.log_api import LogApi
11
11
  from hatchet_sdk.v0.clients.rest.api.step_run_api import StepRunApi
12
+ from hatchet_sdk.v0.clients.rest.api.worker_api import WorkerApi
12
13
  from hatchet_sdk.v0.clients.rest.api.workflow_api import WorkflowApi
13
14
  from hatchet_sdk.v0.clients.rest.api.workflow_run_api import WorkflowRunApi
14
15
  from hatchet_sdk.v0.clients.rest.api.workflow_runs_api import WorkflowRunsApi
@@ -85,6 +86,7 @@ class AsyncRestApi:
85
86
  self._step_run_api = None
86
87
  self._event_api = None
87
88
  self._log_api = None
89
+ self._worker_api = None
88
90
 
89
91
  @property
90
92
  def api_client(self):
@@ -104,6 +106,13 @@ class AsyncRestApi:
104
106
  self._workflow_run_api = WorkflowRunApi(self.api_client)
105
107
  return self._workflow_run_api
106
108
 
@property
def worker_api(self):
    """Return the ``WorkerApi`` for this client, creating it on first access.

    The instance is built from ``self.api_client`` once and cached on
    ``self._worker_api`` for subsequent calls.
    """
    cached = self._worker_api
    if cached is None:
        cached = WorkerApi(self.api_client)
        self._worker_api = cached
    return cached
107
116
  @property
108
117
  def step_run_api(self):
109
118
  if self._step_run_api is None:
@@ -8,12 +8,14 @@ from typing import Any, List, Mapping, Optional
8
8
 
9
9
  import grpc
10
10
 
11
+ from hatchet_sdk.v0.client import Client, new_client_raw
11
12
  from hatchet_sdk.v0.clients.dispatcher.action_listener import Action
12
13
  from hatchet_sdk.v0.clients.dispatcher.dispatcher import (
13
14
  ActionListener,
14
15
  GetActionListenerRequest,
15
16
  new_dispatcher,
16
17
  )
18
+ from hatchet_sdk.v0.clients.rest.models.update_worker_request import UpdateWorkerRequest
17
19
  from hatchet_sdk.v0.contracts.dispatcher_pb2 import (
18
20
  GROUP_KEY_EVENT_TYPE_STARTED,
19
21
  STEP_EVENT_TYPE_STARTED,
@@ -70,9 +72,15 @@ class WorkerActionListenerProcess:
70
72
  if self.debug:
71
73
  logger.setLevel(logging.DEBUG)
72
74
 
75
+ self.client = new_client_raw(self.config, self.debug)
76
+
73
77
  loop = asyncio.get_event_loop()
74
- loop.add_signal_handler(signal.SIGINT, noop_handler)
75
- loop.add_signal_handler(signal.SIGTERM, noop_handler)
78
+ loop.add_signal_handler(
79
+ signal.SIGINT, lambda: asyncio.create_task(self.pause_task_assignment())
80
+ )
81
+ loop.add_signal_handler(
82
+ signal.SIGTERM, lambda: asyncio.create_task(self.pause_task_assignment())
83
+ )
76
84
  loop.add_signal_handler(
77
85
  signal.SIGQUIT, lambda: asyncio.create_task(self.exit_gracefully())
78
86
  )
@@ -249,7 +257,15 @@ class WorkerActionListenerProcess:
249
257
 
250
258
  self.event_queue.put(STOP_LOOP)
251
259
 
async def pause_task_assignment(self) -> None:
    """Ask the Hatchet REST API to mark this worker as paused.

    This is awaited by the SIGINT/SIGTERM signal-handler tasks and at the
    start of ``exit_gracefully``. The call is best-effort: any failure is
    logged instead of raised. A failure here can come from the REST request
    or from ``self.listener`` not being available yet. Raising would abort
    the shutdown path that awaits this method, or surface as an unretrieved
    exception from a signal-handler task.
    """
    try:
        await self.client.rest.aio.worker_api.worker_update(
            worker=self.listener.worker_id,
            update_worker_request=UpdateWorkerRequest(isPaused=True),
        )
    except Exception:
        logger.exception("failed to pause task assignment for worker")
252
266
  async def exit_gracefully(self, skip_unregister=False):
267
+ await self.pause_task_assignment()
268
+
253
269
  if self.killing:
254
270
  return
255
271
 
hatchet_sdk/waits.py ADDED
@@ -0,0 +1,120 @@
1
+ from abc import ABC, abstractmethod
2
+ from enum import Enum
3
+ from typing import TYPE_CHECKING
4
+ from uuid import uuid4
5
+
6
+ from pydantic import BaseModel, Field
7
+
8
+ from hatchet_sdk.contracts.v1.shared.condition_pb2 import Action as ProtoAction
9
+ from hatchet_sdk.contracts.v1.shared.condition_pb2 import (
10
+ BaseMatchCondition,
11
+ ParentOverrideMatchCondition,
12
+ SleepMatchCondition,
13
+ UserEventMatchCondition,
14
+ )
15
+ from hatchet_sdk.utils.proto_enums import convert_python_enum_to_proto
16
+ from hatchet_sdk.utils.timedelta_to_expression import Duration, timedelta_to_expr
17
+
18
+ if TYPE_CHECKING:
19
+ from hatchet_sdk.runnables.task import Task
20
+ from hatchet_sdk.runnables.types import R, TWorkflowInput
21
+
22
+
def generate_or_group_id() -> str:
    """Return a fresh random identifier for an OR group of conditions.

    The identifier is the canonical string form of a random UUID4.
    """
    new_id = uuid4()
    return f"{new_id}"
class Action(Enum):
    """Python-side action attached to a wait condition.

    ``BaseCondition.to_pb`` converts members to the proto ``Action`` enum via
    ``convert_python_enum_to_proto``.
    """

    # Explicit integer values, in declaration order.
    CREATE = 0
    QUEUE = 1
    CANCEL = 2
    SKIP = 3
class BaseCondition(BaseModel):
    """Fields shared by every wait condition, serializable to ``BaseMatchCondition``."""

    readable_data_key: str
    action: Action | None = None
    # Each instance gets its own random OR-group id unless one is supplied.
    or_group_id: str = Field(default_factory=generate_or_group_id)
    expression: str | None = None

    def to_pb(self) -> BaseMatchCondition:
        """Build the protobuf ``BaseMatchCondition`` for this condition."""
        proto_action = convert_python_enum_to_proto(self.action, ProtoAction)  # type: ignore[arg-type]
        return BaseMatchCondition(
            readable_data_key=self.readable_data_key,
            action=proto_action,
            or_group_id=self.or_group_id,
            expression=self.expression,
        )
class Condition(ABC):
    """Abstract wait condition that wraps a ``BaseCondition``.

    Concrete subclasses implement ``to_pb`` to emit their specific proto
    message, embedding ``self.base.to_pb()``.
    """

    def __init__(self, base: BaseCondition):
        self.base = base

    @abstractmethod
    def to_pb(
        self,
    ) -> UserEventMatchCondition | ParentOverrideMatchCondition | SleepMatchCondition:
        """Return the protobuf message describing this condition."""
class SleepCondition(Condition):
    """Condition describing a fixed-duration sleep.

    The base condition's ``readable_data_key`` is ``"sleep:<expr>"``, where
    ``<expr>`` is the duration rendered by ``timedelta_to_expr``.
    """

    def __init__(self, duration: Duration) -> None:
        sleep_expr = timedelta_to_expr(duration)
        super().__init__(BaseCondition(readable_data_key=f"sleep:{sleep_expr}"))
        self.duration = duration

    def to_pb(self) -> SleepMatchCondition:
        """Serialize to a ``SleepMatchCondition`` proto."""
        base_pb = self.base.to_pb()
        return SleepMatchCondition(
            base=base_pb,
            sleep_for=timedelta_to_expr(self.duration),
        )
class UserEventCondition(Condition):
    """Condition tied to a user event identified by ``event_key``.

    ``expression``, when given, is stored on the base condition's
    ``expression`` field. How it is evaluated is decided outside this module.
    """

    def __init__(self, event_key: str, expression: str | None = None) -> None:
        base = BaseCondition(readable_data_key=event_key, expression=expression)
        super().__init__(base)
        self.event_key = event_key
        self.expression = expression

    def to_pb(self) -> UserEventMatchCondition:
        """Serialize to a ``UserEventMatchCondition`` proto."""
        base_pb = self.base.to_pb()
        return UserEventMatchCondition(base=base_pb, user_event_key=self.event_key)
class ParentCondition(Condition):
    """Condition that references a parent task by its ``name``.

    The parent's name is used both as the base ``readable_data_key`` and as
    ``parent_readable_id`` in the emitted proto.
    """

    def __init__(
        self, parent: "Task[TWorkflowInput, R]", expression: str | None = None
    ) -> None:
        base = BaseCondition(readable_data_key=parent.name, expression=expression)
        super().__init__(base)
        self.parent = parent

    def to_pb(self) -> ParentOverrideMatchCondition:
        """Serialize to a ``ParentOverrideMatchCondition`` proto."""
        base_pb = self.base.to_pb()
        return ParentOverrideMatchCondition(
            base=base_pb, parent_readable_id=self.parent.name
        )
class OrGroup:
    """A list of conditions together with a freshly generated group id.

    NOTE(review): the generated ``or_group_id`` is not copied onto each
    condition's ``base.or_group_id`` here. Confirm that the consumer of
    ``OrGroup`` applies it.
    """

    def __init__(self, conditions: list[Condition]) -> None:
        self.or_group_id = generate_or_group_id()
        self.conditions = conditions
def or_(*conditions: Condition) -> OrGroup:
    """Bundle the given conditions into a single ``OrGroup``."""
    grouped = [*conditions]
    return OrGroup(conditions=grouped)
@@ -2,12 +2,13 @@ import asyncio
2
2
  import logging
3
3
  import signal
4
4
  import time
5
- from dataclasses import dataclass, field
5
+ from dataclasses import dataclass
6
6
  from multiprocessing import Queue
7
- from typing import Any, List, Literal
7
+ from typing import Any, Literal
8
8
 
9
9
  import grpc
10
10
 
11
+ from hatchet_sdk.client import Client
11
12
  from hatchet_sdk.clients.dispatcher.action_listener import (
12
13
  Action,
13
14
  ActionListener,
@@ -15,12 +16,18 @@ from hatchet_sdk.clients.dispatcher.action_listener import (
15
16
  GetActionListenerRequest,
16
17
  )
17
18
  from hatchet_sdk.clients.dispatcher.dispatcher import DispatcherClient
19
+ from hatchet_sdk.clients.rest.models.update_worker_request import UpdateWorkerRequest
18
20
  from hatchet_sdk.config import ClientConfig
19
21
  from hatchet_sdk.contracts.dispatcher_pb2 import (
20
22
  GROUP_KEY_EVENT_TYPE_STARTED,
21
23
  STEP_EVENT_TYPE_STARTED,
22
24
  )
23
25
  from hatchet_sdk.logger import logger
26
+ from hatchet_sdk.runnables.contextvars import (
27
+ ctx_step_run_id,
28
+ ctx_worker_id,
29
+ ctx_workflow_run_id,
30
+ )
24
31
  from hatchet_sdk.utils.backoff import exp_backoff_sleep
25
32
 
26
33
  ACTION_EVENT_RETRY_COUNT = 5
@@ -42,42 +49,60 @@ BLOCKED_THREAD_WARNING = (
42
49
  )
43
50
 
44
51
 
45
- def noop_handler() -> None:
46
- pass
47
-
48
-
49
- @dataclass
50
52
  class WorkerActionListenerProcess:
51
- name: str
52
- actions: List[str]
53
- max_runs: int
54
- config: ClientConfig
55
- action_queue: "Queue[Action]"
56
- event_queue: "Queue[ActionEvent | STOP_LOOP_TYPE]"
57
- handle_kill: bool = True
58
- debug: bool = False
59
- labels: dict[str, str | int] = field(default_factory=dict)
60
-
61
- listener: ActionListener = field(init=False)
def __init__(
    self,
    name: str,
    actions: list[str],
    slots: int,
    config: ClientConfig,
    action_queue: "Queue[Action]",
    event_queue: "Queue[ActionEvent | STOP_LOOP_TYPE]",
    handle_kill: bool = True,
    debug: bool = False,
    labels: dict[str, str | int] | None = None,
) -> None:
    """Initialize listener state, the API client and process signal handlers.

    Args:
        name: Worker name used when registering the action listener.
        actions: Action identifiers this worker listens for.
        slots: Slot count passed to the action listener request.
        config: Client configuration.
        action_queue: Queue that received actions are placed on.
        event_queue: Queue of action events (or the stop sentinel).
        handle_kill: Stored on the instance for use by other methods.
        debug: When true, the module logger is set to DEBUG.
        labels: Worker labels. ``None`` (the default) means a new, empty
            dict for this instance.
    """
    self.name = name
    self.actions = actions
    self.slots = slots
    self.config = config
    self.action_queue = action_queue
    self.event_queue = event_queue
    self.debug = debug
    # A literal ``{}`` default is evaluated once and shared by every instance,
    # so build a fresh dict per instance instead.
    self.labels = labels if labels is not None else {}
    self.handle_kill = handle_kill

    self.listener: ActionListener | None = None
    self.killing = False
    self.action_loop_task: asyncio.Task[None] | None = None
    self.event_send_loop_task: asyncio.Task[None] | None = None
    self.running_step_runs: dict[str, float] = {}
    # Strong references to tasks spawned from signal handlers. The event loop
    # keeps only weak references to tasks, so an unreferenced task can be
    # garbage-collected before it finishes.
    self._background_tasks: set[asyncio.Task[None]] = set()

    if self.debug:
        logger.setLevel(logging.DEBUG)

    self.client = Client(config=self.config, debug=self.debug)

    loop = asyncio.get_event_loop()
    # SIGINT/SIGTERM only pause task assignment; SIGQUIT runs exit_gracefully.
    loop.add_signal_handler(
        signal.SIGINT, lambda: self._run_in_background(self.pause_task_assignment())
    )
    loop.add_signal_handler(
        signal.SIGTERM, lambda: self._run_in_background(self.pause_task_assignment())
    )
    loop.add_signal_handler(
        signal.SIGQUIT, lambda: self._run_in_background(self.exit_gracefully())
    )


def _run_in_background(self, coro: Any) -> None:
    """Schedule ``coro`` as a task and keep it referenced until it completes."""
    task = asyncio.create_task(coro)
    self._background_tasks.add(task)
    task.add_done_callback(self._background_tasks.discard)
80
96
 
async def pause_task_assignment(self) -> None:
    """Ask the Hatchet API to mark this worker as paused.

    This is awaited by the SIGINT/SIGTERM signal-handler tasks and at the
    start of ``exit_gracefully``. Both callers need it to be best-effort:

    * If no listener has been started yet, there is no registered worker to
      pause, so a warning is logged and the method returns. Raising would
      abort ``exit_gracefully`` before its own cleanup runs.
    * If the API call fails, the error is logged rather than propagated.
    """
    if self.listener is None:
        logger.warning("cannot pause task assignment: listener not started")
        return

    worker_id = self.listener.worker_id
    try:
        await self.client.workers.aio_update(
            worker_id=worker_id,
            opts=UpdateWorkerRequest(isPaused=True),
        )
    except Exception:
        logger.exception("failed to pause task assignment for worker %s", worker_id)
81
106
  async def start(self, retry_attempt: int = 0) -> None:
82
107
  if retry_attempt > 5:
83
108
  logger.error("could not start action listener")
@@ -93,8 +118,8 @@ class WorkerActionListenerProcess:
93
118
  worker_name=self.name,
94
119
  services=["default"],
95
120
  actions=self.actions,
96
- max_runs=self.max_runs,
97
- _labels=self.labels,
121
+ slots=self.slots,
122
+ raw_labels=self.labels,
98
123
  )
99
124
  )
100
125
 
@@ -190,11 +215,18 @@ class WorkerActionListenerProcess:
190
215
  return time.time()
191
216
 
192
217
  async def start_action_loop(self) -> None:
218
+ if self.listener is None:
219
+ raise ValueError("listener not started")
220
+
193
221
  try:
194
222
  async for action in self.listener:
195
223
  if action is None:
196
224
  break
197
225
 
226
+ ctx_step_run_id.set(action.step_run_id)
227
+ ctx_workflow_run_id.set(action.workflow_run_id)
228
+ ctx_worker_id.set(action.worker_id)
229
+
198
230
  # Process the action here
199
231
  match action.action_type:
200
232
  case ActionType.START_STEP_RUN:
@@ -253,6 +285,8 @@ class WorkerActionListenerProcess:
253
285
  self.event_queue.put(STOP_LOOP)
254
286
 
255
287
  async def exit_gracefully(self) -> None:
288
+ await self.pause_task_assignment()
289
+
256
290
  if self.killing:
257
291
  return
258
292
 
@@ -1,18 +1,17 @@
1
1
  import asyncio
2
2
  import logging
3
- from dataclasses import dataclass, field
4
3
  from multiprocessing import Queue
5
4
  from typing import Any, Literal, TypeVar
6
5
 
7
- from hatchet_sdk.client import Client, new_client_raw
6
+ from hatchet_sdk.client import Client
8
7
  from hatchet_sdk.clients.dispatcher.action_listener import Action
9
8
  from hatchet_sdk.config import ClientConfig
10
9
  from hatchet_sdk.logger import logger
10
+ from hatchet_sdk.runnables.task import Task
11
11
  from hatchet_sdk.utils.typing import WorkflowValidator
12
12
  from hatchet_sdk.worker.action_listener_process import ActionEvent
13
13
  from hatchet_sdk.worker.runner.runner import Runner
14
14
  from hatchet_sdk.worker.runner.utils.capture_logs import capture_logs
15
- from hatchet_sdk.workflow import Step
16
15
 
17
16
  STOP_LOOP_TYPE = Literal["STOP_LOOP"]
18
17
  STOP_LOOP: STOP_LOOP_TYPE = "STOP_LOOP"
@@ -20,29 +19,40 @@ STOP_LOOP: STOP_LOOP_TYPE = "STOP_LOOP"
20
19
  T = TypeVar("T")
21
20
 
22
21
 
23
- @dataclass
24
22
  class WorkerActionRunLoopManager:
25
- name: str
26
- action_registry: dict[str, Step[Any]]
27
- validator_registry: dict[str, WorkflowValidator]
28
- max_runs: int | None
29
- config: ClientConfig
30
- action_queue: "Queue[Action | STOP_LOOP_TYPE]"
31
- event_queue: "Queue[ActionEvent]"
32
- loop: asyncio.AbstractEventLoop
33
- handle_kill: bool = True
34
- debug: bool = False
35
- labels: dict[str, str | int] = field(default_factory=dict)
36
-
37
- client: Client = field(init=False)
38
-
39
- killing: bool = field(init=False, default=False)
40
- runner: Runner | None = field(init=False, default=None)
41
-
42
- def __post_init__(self) -> None:
def __init__(
    self,
    name: str,
    action_registry: dict[str, Task[Any, Any]],
    validator_registry: dict[str, WorkflowValidator],
    slots: int | None,
    config: ClientConfig,
    action_queue: "Queue[Action | STOP_LOOP_TYPE]",
    event_queue: "Queue[ActionEvent]",
    loop: asyncio.AbstractEventLoop,
    handle_kill: bool = True,
    debug: bool = False,
    labels: dict[str, str | int] | None = None,
) -> None:
    """Store run-loop configuration, create the API client, then call ``start()``.

    Args:
        name: Worker name.
        action_registry: Mapping of action id to the ``Task`` that handles it.
        validator_registry: Mapping of action id to its workflow validator.
        slots: Slot count passed through to the ``Runner``.
        config: Client configuration.
        action_queue: Queue of incoming actions (or the stop sentinel).
        event_queue: Queue that action events are written to.
        loop: Event loop stored on the instance.
        handle_kill: Passed through to the ``Runner``.
        debug: When true, the module logger is set to DEBUG.
        labels: Worker labels. ``None`` (the default) means a new, empty
            dict for this instance.
    """
    self.name = name
    self.action_registry = action_registry
    self.validator_registry = validator_registry
    self.slots = slots
    self.config = config
    self.action_queue = action_queue
    self.event_queue = event_queue
    self.loop = loop
    self.handle_kill = handle_kill
    self.debug = debug
    # A literal ``{}`` default is evaluated once and shared by every instance,
    # so build a fresh dict per instance instead.
    self.labels = labels if labels is not None else {}

    if self.debug:
        logger.setLevel(logging.DEBUG)

    self.killing = False
    self.runner: Runner | None = None

    self.client = Client(config=self.config, debug=self.debug)
    self.start()
47
57
 
48
58
  def start(self, retry_count: int = 1) -> None:
@@ -73,13 +83,12 @@ class WorkerActionRunLoopManager:
73
83
 
74
84
  async def _start_action_loop(self) -> None:
75
85
  self.runner = Runner(
76
- self.name,
77
86
  self.event_queue,
78
- self.max_runs,
87
+ self.config,
88
+ self.slots,
79
89
  self.handle_kill,
80
90
  self.action_registry,
81
91
  self.validator_registry,
82
- self.config,
83
92
  self.labels,
84
93
  )
85
94