hatchet-sdk 1.0.0__py3-none-any.whl → 1.0.0a1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of hatchet-sdk might be problematic.
- hatchet_sdk/__init__.py +27 -16
- hatchet_sdk/client.py +13 -63
- hatchet_sdk/clients/admin.py +203 -124
- hatchet_sdk/clients/dispatcher/action_listener.py +42 -42
- hatchet_sdk/clients/dispatcher/dispatcher.py +18 -16
- hatchet_sdk/clients/durable_event_listener.py +327 -0
- hatchet_sdk/clients/rest/__init__.py +12 -1
- hatchet_sdk/clients/rest/api/log_api.py +258 -0
- hatchet_sdk/clients/rest/api/task_api.py +32 -6
- hatchet_sdk/clients/rest/api/workflow_runs_api.py +626 -0
- hatchet_sdk/clients/rest/models/__init__.py +12 -1
- hatchet_sdk/clients/rest/models/v1_log_line.py +94 -0
- hatchet_sdk/clients/rest/models/v1_log_line_level.py +39 -0
- hatchet_sdk/clients/rest/models/v1_log_line_list.py +110 -0
- hatchet_sdk/clients/rest/models/v1_task_summary.py +80 -64
- hatchet_sdk/clients/rest/models/v1_trigger_workflow_run_request.py +95 -0
- hatchet_sdk/clients/rest/models/v1_workflow_run_display_name.py +98 -0
- hatchet_sdk/clients/rest/models/v1_workflow_run_display_name_list.py +114 -0
- hatchet_sdk/clients/rest/models/workflow_run_shape_item_for_workflow_run_details.py +9 -4
- hatchet_sdk/clients/rest_client.py +21 -0
- hatchet_sdk/clients/run_event_listener.py +0 -1
- hatchet_sdk/context/context.py +85 -147
- hatchet_sdk/contracts/dispatcher_pb2_grpc.py +1 -1
- hatchet_sdk/contracts/events_pb2.py +2 -2
- hatchet_sdk/contracts/events_pb2_grpc.py +1 -1
- hatchet_sdk/contracts/v1/dispatcher_pb2.py +36 -0
- hatchet_sdk/contracts/v1/dispatcher_pb2.pyi +38 -0
- hatchet_sdk/contracts/v1/dispatcher_pb2_grpc.py +145 -0
- hatchet_sdk/contracts/v1/shared/condition_pb2.py +39 -0
- hatchet_sdk/contracts/v1/shared/condition_pb2.pyi +72 -0
- hatchet_sdk/contracts/v1/shared/condition_pb2_grpc.py +29 -0
- hatchet_sdk/contracts/v1/workflows_pb2.py +67 -0
- hatchet_sdk/contracts/v1/workflows_pb2.pyi +228 -0
- hatchet_sdk/contracts/v1/workflows_pb2_grpc.py +234 -0
- hatchet_sdk/contracts/workflows_pb2_grpc.py +1 -1
- hatchet_sdk/features/cron.py +3 -3
- hatchet_sdk/features/scheduled.py +2 -2
- hatchet_sdk/hatchet.py +427 -151
- hatchet_sdk/opentelemetry/instrumentor.py +8 -13
- hatchet_sdk/rate_limit.py +33 -39
- hatchet_sdk/runnables/contextvars.py +12 -0
- hatchet_sdk/runnables/standalone.py +194 -0
- hatchet_sdk/runnables/task.py +144 -0
- hatchet_sdk/runnables/types.py +138 -0
- hatchet_sdk/runnables/workflow.py +764 -0
- hatchet_sdk/utils/aio_utils.py +0 -79
- hatchet_sdk/utils/proto_enums.py +0 -7
- hatchet_sdk/utils/timedelta_to_expression.py +23 -0
- hatchet_sdk/utils/typing.py +2 -2
- hatchet_sdk/v0/clients/rest_client.py +9 -0
- hatchet_sdk/v0/worker/action_listener_process.py +18 -2
- hatchet_sdk/waits.py +120 -0
- hatchet_sdk/worker/action_listener_process.py +64 -30
- hatchet_sdk/worker/runner/run_loop_manager.py +35 -25
- hatchet_sdk/worker/runner/runner.py +72 -49
- hatchet_sdk/worker/runner/utils/capture_logs.py +3 -11
- hatchet_sdk/worker/worker.py +155 -118
- hatchet_sdk/workflow_run.py +4 -5
- {hatchet_sdk-1.0.0.dist-info → hatchet_sdk-1.0.0a1.dist-info}/METADATA +1 -2
- {hatchet_sdk-1.0.0.dist-info → hatchet_sdk-1.0.0a1.dist-info}/RECORD +62 -42
- {hatchet_sdk-1.0.0.dist-info → hatchet_sdk-1.0.0a1.dist-info}/entry_points.txt +2 -0
- hatchet_sdk/semver.py +0 -30
- hatchet_sdk/worker/runner/utils/error_with_traceback.py +0 -6
- hatchet_sdk/workflow.py +0 -527
- {hatchet_sdk-1.0.0.dist-info → hatchet_sdk-1.0.0a1.dist-info}/WHEEL +0 -0
hatchet_sdk/utils/aio_utils.py
CHANGED
@@ -1,83 +1,4 @@
 import asyncio
-import inspect
-from concurrent.futures import Executor
-from functools import partial, wraps
-from typing import Any
-
-
-## TODO: Stricter typing here
-def sync_to_async(func: Any) -> Any:
-    """
-    A decorator to run a synchronous function or coroutine in an asynchronous context with added
-    asyncio loop safety.
-
-    This decorator allows you to safely call synchronous functions or coroutines from an
-    asynchronous function by running them in an executor.
-
-    Args:
-        func (callable): The synchronous function or coroutine to be run asynchronously.
-
-    Returns:
-        callable: An asynchronous wrapper function that runs the given function in an executor.
-
-    Example:
-        @sync_to_async
-        def sync_function(x, y):
-            return x + y
-
-        @sync_to_async
-        async def async_function(x, y):
-            return x + y
-
-
-        def undecorated_function(x, y):
-            return x + y
-
-        async def main():
-            result1 = await sync_function(1, 2)
-            result2 = await async_function(3, 4)
-            result3 = await sync_to_async(undecorated_function)(5, 6)
-            print(result1, result2, result3)
-
-        asyncio.run(main())
-    """
-
-    ## TODO: Stricter typing here
-    @wraps(func)
-    async def run(
-        *args: Any,
-        loop: asyncio.AbstractEventLoop | None = None,
-        executor: Executor | None = None,
-        **kwargs: Any
-    ) -> Any:
-        """
-        The asynchronous wrapper function that runs the given function in an executor.
-
-        Args:
-            *args: Positional arguments to pass to the function.
-            loop (asyncio.AbstractEventLoop, optional): The event loop to use. If None, the current running loop is used.
-            executor (concurrent.futures.Executor, optional): The executor to use. If None, the default executor is used.
-            **kwargs: Keyword arguments to pass to the function.
-
-        Returns:
-            The result of the function call.
-        """
-        if loop is None:
-            loop = asyncio.get_running_loop()
-
-        if inspect.iscoroutinefunction(func):
-            # Wrap the coroutine to run it in an executor
-            async def wrapper() -> Any:
-                return await func(*args, **kwargs)
-
-            pfunc = partial(asyncio.run, wrapper())
-            return await loop.run_in_executor(executor, pfunc)
-        else:
-            # Run the synchronous function in an executor
-            pfunc = partial(func, *args, **kwargs)
-            return await loop.run_in_executor(executor, pfunc)
-
-    return run
 
 
 def get_active_event_loop() -> asyncio.AbstractEventLoop | None:
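For downstream code that imported the removed sync_to_async helper, the closest stdlib equivalent for the synchronous path is asyncio.to_thread, which also runs a blocking callable in the default executor. This is an illustrative replacement sketch only, not something this release ships:

import asyncio
import time


def blocking_add(x: int, y: int) -> int:
    time.sleep(0.1)  # stand-in for blocking work
    return x + y


async def main() -> None:
    # Run the synchronous callable in the default thread-pool executor.
    result = await asyncio.to_thread(blocking_add, 1, 2)
    print(result)


asyncio.run(main())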
hatchet_sdk/utils/timedelta_to_expression.py
ADDED
@@ -0,0 +1,23 @@
+from datetime import timedelta
+
+DAY = 86400
+HOUR = 3600
+MINUTE = 60
+
+Duration = timedelta | str
+
+
+def timedelta_to_expr(td: Duration) -> str:
+    if isinstance(td, str):
+        return td
+
+    seconds = td.seconds
+
+    if seconds % DAY == 0:
+        return f"{seconds // DAY}d"
+    elif seconds % HOUR == 0:
+        return f"{seconds // HOUR}h"
+    elif seconds % MINUTE == 0:
+        return f"{seconds // MINUTE}m"
+    else:
+        return f"{seconds}s"
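A quick usage sketch of the new helper (illustrative, not part of the diff); note that it reads td.seconds, so only the sub-day component of a timedelta is considered:

from datetime import timedelta

from hatchet_sdk.utils.timedelta_to_expression import timedelta_to_expr

print(timedelta_to_expr(timedelta(minutes=5)))  # "5m" (300 seconds, divisible by 60)
print(timedelta_to_expr(timedelta(hours=2)))    # "2h" (7200 seconds, divisible by 3600)
print(timedelta_to_expr("30s"))                 # strings pass through unchanged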
hatchet_sdk/utils/typing.py
CHANGED
@@ -1,11 +1,11 @@
-from typing import Any, Mapping, Type, TypeVar
+from typing import Any, Mapping, Type, TypeGuard, TypeVar
 
 from pydantic import BaseModel
 
 T = TypeVar("T", bound=BaseModel)
 
 
-def is_basemodel_subclass(model: Any) ->
+def is_basemodel_subclass(model: Any) -> TypeGuard[Type[BaseModel]]:
     try:
         return issubclass(model, BaseModel)
     except TypeError:
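A sketch of what the TypeGuard return type buys callers: inside the guarded branch, static type checkers narrow the argument to Type[BaseModel]. The surrounding helper below is hypothetical and assumes pydantic v2:

from typing import Any

from pydantic import BaseModel

from hatchet_sdk.utils.typing import is_basemodel_subclass


class MyInput(BaseModel):
    name: str


def dump_schema(model: Any) -> dict[str, Any] | None:
    if is_basemodel_subclass(model):
        # Narrowed to Type[BaseModel] here, so classmethods are safe to call.
        return model.model_json_schema()
    return None


print(dump_schema(MyInput) is not None)  # True
print(dump_schema(42))                   # None: not a BaseModel subclass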
hatchet_sdk/v0/clients/rest_client.py
CHANGED
@@ -9,6 +9,7 @@ from pydantic import StrictInt
 from hatchet_sdk.v0.clients.rest.api.event_api import EventApi
 from hatchet_sdk.v0.clients.rest.api.log_api import LogApi
 from hatchet_sdk.v0.clients.rest.api.step_run_api import StepRunApi
+from hatchet_sdk.v0.clients.rest.api.worker_api import WorkerApi
 from hatchet_sdk.v0.clients.rest.api.workflow_api import WorkflowApi
 from hatchet_sdk.v0.clients.rest.api.workflow_run_api import WorkflowRunApi
 from hatchet_sdk.v0.clients.rest.api.workflow_runs_api import WorkflowRunsApi
@@ -85,6 +86,7 @@ class AsyncRestApi:
         self._step_run_api = None
         self._event_api = None
         self._log_api = None
+        self._worker_api = None
 
     @property
     def api_client(self):
@@ -104,6 +106,13 @@ class AsyncRestApi:
             self._workflow_run_api = WorkflowRunApi(self.api_client)
         return self._workflow_run_api
 
+    @property
+    def worker_api(self):
+        if self._worker_api is None:
+            self._worker_api = WorkerApi(self.api_client)
+
+        return self._worker_api
+
     @property
     def step_run_api(self):
         if self._step_run_api is None:
hatchet_sdk/v0/worker/action_listener_process.py
CHANGED
@@ -8,12 +8,14 @@ from typing import Any, List, Mapping, Optional
 
 import grpc
 
+from hatchet_sdk.v0.client import Client, new_client_raw
 from hatchet_sdk.v0.clients.dispatcher.action_listener import Action
 from hatchet_sdk.v0.clients.dispatcher.dispatcher import (
     ActionListener,
     GetActionListenerRequest,
     new_dispatcher,
 )
+from hatchet_sdk.v0.clients.rest.models.update_worker_request import UpdateWorkerRequest
 from hatchet_sdk.v0.contracts.dispatcher_pb2 import (
     GROUP_KEY_EVENT_TYPE_STARTED,
     STEP_EVENT_TYPE_STARTED,
@@ -70,9 +72,15 @@ class WorkerActionListenerProcess:
         if self.debug:
             logger.setLevel(logging.DEBUG)
 
+        self.client = new_client_raw(self.config, self.debug)
+
         loop = asyncio.get_event_loop()
-        loop.add_signal_handler(
-
+        loop.add_signal_handler(
+            signal.SIGINT, lambda: asyncio.create_task(self.pause_task_assignment())
+        )
+        loop.add_signal_handler(
+            signal.SIGTERM, lambda: asyncio.create_task(self.pause_task_assignment())
+        )
         loop.add_signal_handler(
             signal.SIGQUIT, lambda: asyncio.create_task(self.exit_gracefully())
         )
@@ -249,7 +257,15 @@ class WorkerActionListenerProcess:
 
         self.event_queue.put(STOP_LOOP)
 
+    async def pause_task_assignment(self) -> None:
+        await self.client.rest.aio.worker_api.worker_update(
+            worker=self.listener.worker_id,
+            update_worker_request=UpdateWorkerRequest(isPaused=True),
+        )
+
     async def exit_gracefully(self, skip_unregister=False):
+        await self.pause_task_assignment()
+
         if self.killing:
             return
 
hatchet_sdk/waits.py
ADDED
@@ -0,0 +1,120 @@
+from abc import ABC, abstractmethod
+from enum import Enum
+from typing import TYPE_CHECKING
+from uuid import uuid4
+
+from pydantic import BaseModel, Field
+
+from hatchet_sdk.contracts.v1.shared.condition_pb2 import Action as ProtoAction
+from hatchet_sdk.contracts.v1.shared.condition_pb2 import (
+    BaseMatchCondition,
+    ParentOverrideMatchCondition,
+    SleepMatchCondition,
+    UserEventMatchCondition,
+)
+from hatchet_sdk.utils.proto_enums import convert_python_enum_to_proto
+from hatchet_sdk.utils.timedelta_to_expression import Duration, timedelta_to_expr
+
+if TYPE_CHECKING:
+    from hatchet_sdk.runnables.task import Task
+    from hatchet_sdk.runnables.types import R, TWorkflowInput
+
+
+def generate_or_group_id() -> str:
+    return str(uuid4())
+
+
+class Action(Enum):
+    CREATE = 0
+    QUEUE = 1
+    CANCEL = 2
+    SKIP = 3
+
+
+class BaseCondition(BaseModel):
+    readable_data_key: str
+    action: Action | None = None
+    or_group_id: str = Field(default_factory=generate_or_group_id)
+    expression: str | None = None
+
+    def to_pb(self) -> BaseMatchCondition:
+        return BaseMatchCondition(
+            readable_data_key=self.readable_data_key,
+            action=convert_python_enum_to_proto(self.action, ProtoAction),  # type: ignore[arg-type]
+            or_group_id=self.or_group_id,
+            expression=self.expression,
+        )
+
+
+class Condition(ABC):
+    def __init__(self, base: BaseCondition):
+        self.base = base
+
+    @abstractmethod
+    def to_pb(
+        self,
+    ) -> UserEventMatchCondition | ParentOverrideMatchCondition | SleepMatchCondition:
+        pass
+
+
+class SleepCondition(Condition):
+    def __init__(self, duration: Duration) -> None:
+        super().__init__(
+            BaseCondition(
+                readable_data_key=f"sleep:{timedelta_to_expr(duration)}",
+            )
+        )
+
+        self.duration = duration
+
+    def to_pb(self) -> SleepMatchCondition:
+        return SleepMatchCondition(
+            base=self.base.to_pb(),
+            sleep_for=timedelta_to_expr(self.duration),
+        )
+
+
+class UserEventCondition(Condition):
+    def __init__(self, event_key: str, expression: str | None = None) -> None:
+        super().__init__(
+            BaseCondition(
+                readable_data_key=event_key,
+                expression=expression,
+            )
+        )
+
+        self.event_key = event_key
+        self.expression = expression
+
+    def to_pb(self) -> UserEventMatchCondition:
+        return UserEventMatchCondition(
+            base=self.base.to_pb(),
+            user_event_key=self.event_key,
+        )
+
+
+class ParentCondition(Condition):
+    def __init__(
+        self, parent: "Task[TWorkflowInput, R]", expression: str | None = None
+    ) -> None:
+        super().__init__(
+            BaseCondition(readable_data_key=parent.name, expression=expression)
+        )
+
+        self.parent = parent
+
+    def to_pb(self) -> ParentOverrideMatchCondition:
+        return ParentOverrideMatchCondition(
+            base=self.base.to_pb(),
+            parent_readable_id=self.parent.name,
+        )
+
+
+class OrGroup:
+    def __init__(self, conditions: list[Condition]) -> None:
+        self.or_group_id = generate_or_group_id()
+        self.conditions = conditions
+
+
+def or_(*conditions: Condition) -> OrGroup:
+    return OrGroup(conditions=list(conditions))
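A sketch of composing the new condition types (constructor signatures as shown above; how a task consumes the resulting OrGroup is not part of this file and may differ):

from datetime import timedelta

from hatchet_sdk.waits import SleepCondition, UserEventCondition, or_

# Proceed after a 5 minute sleep, or earlier if a matching user event arrives.
wake_conditions = or_(
    SleepCondition(timedelta(minutes=5)),
    UserEventCondition(event_key="payment:completed", expression="input.amount > 0"),
)
print(len(wake_conditions.conditions))  # 2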
hatchet_sdk/worker/action_listener_process.py
CHANGED
@@ -2,12 +2,13 @@ import asyncio
 import logging
 import signal
 import time
-from dataclasses import dataclass
+from dataclasses import dataclass
 from multiprocessing import Queue
-from typing import Any,
+from typing import Any, Literal
 
 import grpc
 
+from hatchet_sdk.client import Client
 from hatchet_sdk.clients.dispatcher.action_listener import (
     Action,
     ActionListener,
@@ -15,12 +16,18 @@ from hatchet_sdk.clients.dispatcher.action_listener import (
     GetActionListenerRequest,
 )
 from hatchet_sdk.clients.dispatcher.dispatcher import DispatcherClient
+from hatchet_sdk.clients.rest.models.update_worker_request import UpdateWorkerRequest
 from hatchet_sdk.config import ClientConfig
 from hatchet_sdk.contracts.dispatcher_pb2 import (
     GROUP_KEY_EVENT_TYPE_STARTED,
     STEP_EVENT_TYPE_STARTED,
 )
 from hatchet_sdk.logger import logger
+from hatchet_sdk.runnables.contextvars import (
+    ctx_step_run_id,
+    ctx_worker_id,
+    ctx_workflow_run_id,
+)
 from hatchet_sdk.utils.backoff import exp_backoff_sleep
 
 ACTION_EVENT_RETRY_COUNT = 5
@@ -42,42 +49,60 @@ BLOCKED_THREAD_WARNING = (
 )
 
 
-def noop_handler() -> None:
-    pass
-
-
-@dataclass
 class WorkerActionListenerProcess:
-
-
-
-
-
-
-
-
-
-
-
+    def __init__(
+        self,
+        name: str,
+        actions: list[str],
+        slots: int,
+        config: ClientConfig,
+        action_queue: "Queue[Action]",
+        event_queue: "Queue[ActionEvent | STOP_LOOP_TYPE]",
+        handle_kill: bool = True,
+        debug: bool = False,
+        labels: dict[str, str | int] = {},
+    ) -> None:
+        self.name = name
+        self.actions = actions
+        self.slots = slots
+        self.config = config
+        self.action_queue = action_queue
+        self.event_queue = event_queue
+        self.debug = debug
+        self.labels = labels
+        self.handle_kill = handle_kill
+
+        self.listener: ActionListener | None = None
+        self.killing = False
+        self.action_loop_task: asyncio.Task[None] | None = None
+        self.event_send_loop_task: asyncio.Task[None] | None = None
+        self.running_step_runs: dict[str, float] = {}
 
-    killing: bool = field(init=False, default=False)
-
-    action_loop_task: asyncio.Task[None] | None = field(init=False, default=None)
-    event_send_loop_task: asyncio.Task[None] | None = field(init=False, default=None)
-
-    running_step_runs: dict[str, float] = field(init=False, default_factory=dict)
-
-    def __post_init__(self) -> None:
         if self.debug:
             logger.setLevel(logging.DEBUG)
 
+        self.client = Client(config=self.config, debug=self.debug)
+
         loop = asyncio.get_event_loop()
-        loop.add_signal_handler(
-
+        loop.add_signal_handler(
+            signal.SIGINT, lambda: asyncio.create_task(self.pause_task_assignment())
+        )
+        loop.add_signal_handler(
+            signal.SIGTERM, lambda: asyncio.create_task(self.pause_task_assignment())
+        )
         loop.add_signal_handler(
             signal.SIGQUIT, lambda: asyncio.create_task(self.exit_gracefully())
        )
 
+    async def pause_task_assignment(self) -> None:
+        if self.listener is None:
+            raise ValueError("listener not started")
+
+        await self.client.rest.worker_api.worker_update(
+            worker=self.listener.worker_id,
+            update_worker_request=UpdateWorkerRequest(isPaused=True),
+        )
+
     async def start(self, retry_attempt: int = 0) -> None:
         if retry_attempt > 5:
             logger.error("could not start action listener")
@@ -93,8 +118,8 @@ class WorkerActionListenerProcess:
                     worker_name=self.name,
                     services=["default"],
                     actions=self.actions,
-
-
+                    slots=self.slots,
+                    raw_labels=self.labels,
                 )
             )
 
@@ -190,11 +215,18 @@ class WorkerActionListenerProcess:
         return time.time()
 
     async def start_action_loop(self) -> None:
+        if self.listener is None:
+            raise ValueError("listener not started")
+
         try:
            async for action in self.listener:
                 if action is None:
                     break
 
+                ctx_step_run_id.set(action.step_run_id)
+                ctx_workflow_run_id.set(action.workflow_run_id)
+                ctx_worker_id.set(action.worker_id)
+
                 # Process the action here
                 match action.action_type:
                     case ActionType.START_STEP_RUN:
@@ -253,6 +285,8 @@ class WorkerActionListenerProcess:
         self.event_queue.put(STOP_LOOP)
 
     async def exit_gracefully(self) -> None:
+        await self.pause_task_assignment()
+
         if self.killing:
             return
 
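The shutdown change above wires SIGINT and SIGTERM to a coroutine that marks the worker paused before it exits. A generic asyncio sketch of that pattern (names here are illustrative, not the SDK's):

import asyncio
import signal


async def pause_task_assignment() -> None:
    # In the SDK this is where worker_api.worker_update(..., UpdateWorkerRequest(isPaused=True)) is called.
    print("pausing task assignment")


async def main() -> None:
    loop = asyncio.get_running_loop()
    for sig in (signal.SIGINT, signal.SIGTERM):
        # Schedule the pause coroutine when the signal arrives (Unix-only API).
        loop.add_signal_handler(sig, lambda: asyncio.create_task(pause_task_assignment()))
    await asyncio.sleep(5)  # stand-in for the worker's run loop


if __name__ == "__main__":
    asyncio.run(main())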
hatchet_sdk/worker/runner/run_loop_manager.py
CHANGED
@@ -1,18 +1,17 @@
 import asyncio
 import logging
-from dataclasses import dataclass, field
 from multiprocessing import Queue
 from typing import Any, Literal, TypeVar
 
-from hatchet_sdk.client import Client
+from hatchet_sdk.client import Client
 from hatchet_sdk.clients.dispatcher.action_listener import Action
 from hatchet_sdk.config import ClientConfig
 from hatchet_sdk.logger import logger
+from hatchet_sdk.runnables.task import Task
 from hatchet_sdk.utils.typing import WorkflowValidator
 from hatchet_sdk.worker.action_listener_process import ActionEvent
 from hatchet_sdk.worker.runner.runner import Runner
 from hatchet_sdk.worker.runner.utils.capture_logs import capture_logs
-from hatchet_sdk.workflow import Step
 
 STOP_LOOP_TYPE = Literal["STOP_LOOP"]
 STOP_LOOP: STOP_LOOP_TYPE = "STOP_LOOP"
@@ -20,29 +19,40 @@ STOP_LOOP: STOP_LOOP_TYPE = "STOP_LOOP"
 T = TypeVar("T")
 
 
-@dataclass
 class WorkerActionRunLoopManager:
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+    def __init__(
+        self,
+        name: str,
+        action_registry: dict[str, Task[Any, Any]],
+        validator_registry: dict[str, WorkflowValidator],
+        slots: int | None,
+        config: ClientConfig,
+        action_queue: "Queue[Action | STOP_LOOP_TYPE]",
+        event_queue: "Queue[ActionEvent]",
+        loop: asyncio.AbstractEventLoop,
+        handle_kill: bool = True,
+        debug: bool = False,
+        labels: dict[str, str | int] = {},
+    ) -> None:
+        self.name = name
+        self.action_registry = action_registry
+        self.validator_registry = validator_registry
+        self.slots = slots
+        self.config = config
+        self.action_queue = action_queue
+        self.event_queue = event_queue
+        self.loop = loop
+        self.handle_kill = handle_kill
+        self.debug = debug
+        self.labels = labels
+
         if self.debug:
             logger.setLevel(logging.DEBUG)
-
+
+        self.killing = False
+        self.runner: Runner | None = None
+
+        self.client = Client(config=self.config, debug=self.debug)
         self.start()
 
     def start(self, retry_count: int = 1) -> None:
@@ -75,11 +85,11 @@ class WorkerActionRunLoopManager:
         self.runner = Runner(
             self.name,
             self.event_queue,
-            self.
+            self.config,
+            self.slots,
             self.handle_kill,
             self.action_registry,
             self.validator_registry,
-            self.config,
             self.labels,
         )
 