hatchet-sdk 1.12.3__py3-none-any.whl → 1.13.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release. This version of hatchet-sdk might be problematic; see the registry's advisory page for details.
- hatchet_sdk/__init__.py +46 -40
- hatchet_sdk/clients/admin.py +18 -23
- hatchet_sdk/clients/dispatcher/action_listener.py +4 -3
- hatchet_sdk/clients/dispatcher/dispatcher.py +1 -4
- hatchet_sdk/clients/event_ts.py +2 -1
- hatchet_sdk/clients/events.py +16 -12
- hatchet_sdk/clients/listeners/durable_event_listener.py +4 -2
- hatchet_sdk/clients/listeners/pooled_listener.py +2 -2
- hatchet_sdk/clients/listeners/run_event_listener.py +7 -8
- hatchet_sdk/clients/listeners/workflow_listener.py +14 -6
- hatchet_sdk/clients/rest/api_response.py +3 -2
- hatchet_sdk/clients/rest/tenacity_utils.py +6 -8
- hatchet_sdk/config.py +2 -0
- hatchet_sdk/connection.py +10 -4
- hatchet_sdk/context/context.py +170 -46
- hatchet_sdk/context/worker_context.py +4 -7
- hatchet_sdk/contracts/dispatcher_pb2.py +38 -38
- hatchet_sdk/contracts/dispatcher_pb2.pyi +4 -2
- hatchet_sdk/contracts/events_pb2.py +13 -13
- hatchet_sdk/contracts/events_pb2.pyi +4 -2
- hatchet_sdk/contracts/v1/workflows_pb2.py +1 -1
- hatchet_sdk/contracts/v1/workflows_pb2.pyi +2 -2
- hatchet_sdk/exceptions.py +99 -1
- hatchet_sdk/features/cron.py +2 -2
- hatchet_sdk/features/filters.py +3 -3
- hatchet_sdk/features/runs.py +4 -4
- hatchet_sdk/features/scheduled.py +8 -9
- hatchet_sdk/hatchet.py +65 -64
- hatchet_sdk/opentelemetry/instrumentor.py +20 -20
- hatchet_sdk/runnables/action.py +1 -2
- hatchet_sdk/runnables/contextvars.py +19 -0
- hatchet_sdk/runnables/task.py +37 -29
- hatchet_sdk/runnables/types.py +9 -8
- hatchet_sdk/runnables/workflow.py +57 -42
- hatchet_sdk/utils/proto_enums.py +4 -4
- hatchet_sdk/utils/timedelta_to_expression.py +2 -3
- hatchet_sdk/utils/typing.py +11 -17
- hatchet_sdk/waits.py +6 -5
- hatchet_sdk/worker/action_listener_process.py +33 -13
- hatchet_sdk/worker/runner/run_loop_manager.py +15 -11
- hatchet_sdk/worker/runner/runner.py +102 -92
- hatchet_sdk/worker/runner/utils/capture_logs.py +72 -31
- hatchet_sdk/worker/worker.py +29 -25
- hatchet_sdk/workflow_run.py +4 -2
- {hatchet_sdk-1.12.3.dist-info → hatchet_sdk-1.13.0.dist-info}/METADATA +1 -1
- {hatchet_sdk-1.12.3.dist-info → hatchet_sdk-1.13.0.dist-info}/RECORD +48 -48
- {hatchet_sdk-1.12.3.dist-info → hatchet_sdk-1.13.0.dist-info}/WHEEL +0 -0
- {hatchet_sdk-1.12.3.dist-info → hatchet_sdk-1.13.0.dist-info}/entry_points.txt +0 -0
hatchet_sdk/utils/typing.py
CHANGED

@@ -1,20 +1,11 @@
 import sys
-from typing import (
-    Any,
-    Awaitable,
-    Coroutine,
-    Generator,
-    Mapping,
-    Type,
-    TypeAlias,
-    TypeGuard,
-    TypeVar,
-)
+from collections.abc import Awaitable, Coroutine, Generator
+from typing import Any, Literal, TypeAlias, TypeGuard, TypeVar

 from pydantic import BaseModel


-def is_basemodel_subclass(model: Any) -> TypeGuard[Type[BaseModel]]:
+def is_basemodel_subclass(model: Any) -> TypeGuard[type[BaseModel]]:
     try:
         return issubclass(model, BaseModel)
     except TypeError:
@@ -22,18 +13,21 @@ def is_basemodel_subclass(model: Any) -> TypeGuard[Type[BaseModel]]:


 class TaskIOValidator(BaseModel):
-    workflow_input: Type[BaseModel] | None = None
-    step_output: Type[BaseModel] | None = None
+    workflow_input: type[BaseModel] | None = None
+    step_output: type[BaseModel] | None = None


-JSONSerializableMapping = Mapping[str, Any]
+JSONSerializableMapping = dict[str, Any]


 _T_co = TypeVar("_T_co", covariant=True)

 if sys.version_info >= (3, 12):
-    AwaitableLike: TypeAlias = Awaitable[_T_co]
-    CoroutineLike: TypeAlias = Coroutine[Any, Any, _T_co]
+    AwaitableLike: TypeAlias = Awaitable[_T_co]
+    CoroutineLike: TypeAlias = Coroutine[Any, Any, _T_co]
 else:
     AwaitableLike: TypeAlias = Generator[Any, None, _T_co] | Awaitable[_T_co]
     CoroutineLike: TypeAlias = Generator[Any, None, _T_co] | Coroutine[Any, Any, _T_co]
+
+STOP_LOOP_TYPE = Literal["STOP_LOOP"]
+STOP_LOOP: STOP_LOOP_TYPE = "STOP_LOOP"  # Sentinel object to stop the loop
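The STOP_LOOP sentinel moving into utils/typing lets the worker modules below share one definition instead of re-declaring it. A standalone sketch (not SDK code) of how a Literal-typed sentinel terminates a queue consumer while keeping the element type checkable:

import queue
from typing import Literal

STOP_LOOP_TYPE = Literal["STOP_LOOP"]
STOP_LOOP: STOP_LOOP_TYPE = "STOP_LOOP"

def drain(q: "queue.Queue[int | STOP_LOOP_TYPE]") -> list[int]:
    items: list[int] = []
    while True:
        item = q.get()
        if item == STOP_LOOP:
            # comparing against the Literal narrows item to int below
            return items
        items.append(item)

q: "queue.Queue[int | STOP_LOOP_TYPE]" = queue.Queue()
for n in (1, 2, 3):
    q.put(n)
q.put(STOP_LOOP)
assert drain(q) == [1, 2, 3]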
hatchet_sdk/waits.py
CHANGED

@@ -6,6 +6,7 @@ from uuid import uuid4

 from pydantic import BaseModel, Field

+from hatchet_sdk.config import ClientConfig
 from hatchet_sdk.contracts.v1.shared.condition_pb2 import Action as ProtoAction
 from hatchet_sdk.contracts.v1.shared.condition_pb2 import (
     BaseMatchCondition,
@@ -53,7 +54,7 @@ class Condition(ABC):

     @abstractmethod
     def to_proto(
-        self,
+        self, config: ClientConfig
     ) -> UserEventMatchCondition | ParentOverrideMatchCondition | SleepMatchCondition:
         pass

@@ -71,7 +72,7 @@ class SleepCondition(Condition):

         self.duration = duration

-    def to_proto(self) -> SleepMatchCondition:
+    def to_proto(self, config: ClientConfig) -> SleepMatchCondition:
         return SleepMatchCondition(
             base=self.base.to_proto(),
             sleep_for=timedelta_to_expr(self.duration),
@@ -95,10 +96,10 @@ class UserEventCondition(Condition):
         self.event_key = event_key
         self.expression = expression

-    def to_proto(self) -> UserEventMatchCondition:
+    def to_proto(self, config: ClientConfig) -> UserEventMatchCondition:
         return UserEventMatchCondition(
             base=self.base.to_proto(),
-            user_event_key=self.event_key,
+            user_event_key=config.apply_namespace(self.event_key),
         )


@@ -124,7 +125,7 @@ class ParentCondition(Condition):

         self.parent = parent

-    def to_proto(self) -> ParentOverrideMatchCondition:
+    def to_proto(self, config: ClientConfig) -> ParentOverrideMatchCondition:
         return ParentOverrideMatchCondition(
             base=self.base.to_proto(),
             parent_readable_id=self.parent.name,
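to_proto() now threads the ClientConfig through so user-event keys get the tenant namespace applied at proto-construction time. apply_namespace's exact behavior is not shown in this diff; a hypothetical stand-in that prefixes keys gives the idea:

from dataclasses import dataclass

@dataclass
class ConfigSketch:
    # stand-in for hatchet_sdk.config.ClientConfig; the prefixing rule is an assumption
    namespace: str = ""

    def apply_namespace(self, key: str) -> str:
        # assumed behavior: prefix the key unless it is already namespaced
        if self.namespace and not key.startswith(self.namespace):
            return f"{self.namespace}{key}"
        return key

config = ConfigSketch(namespace="dev_")
print(config.apply_namespace("order:created"))      # dev_order:created
print(config.apply_namespace("dev_order:created"))  # already namespaced, unchanged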
hatchet_sdk/worker/action_listener_process.py
CHANGED

@@ -4,9 +4,10 @@ import signal
 import time
 from dataclasses import dataclass
 from multiprocessing import Queue
-from typing import Any, Literal
+from typing import Any

 import grpc
+from grpc.aio import UnaryUnaryCall

 from hatchet_sdk.client import Client
 from hatchet_sdk.clients.dispatcher.action_listener import (
@@ -19,6 +20,9 @@ from hatchet_sdk.config import ClientConfig
 from hatchet_sdk.contracts.dispatcher_pb2 import (
     GROUP_KEY_EVENT_TYPE_STARTED,
     STEP_EVENT_TYPE_STARTED,
+    ActionEventResponse,
+    GroupKeyActionEvent,
+    StepActionEvent,
 )
 from hatchet_sdk.logger import logger
 from hatchet_sdk.runnables.action import Action, ActionType
@@ -29,6 +33,7 @@ from hatchet_sdk.runnables.contextvars import (
     ctx_workflow_run_id,
 )
 from hatchet_sdk.utils.backoff import exp_backoff_sleep
+from hatchet_sdk.utils.typing import STOP_LOOP, STOP_LOOP_TYPE

 ACTION_EVENT_RETRY_COUNT = 5

@@ -41,9 +46,6 @@ class ActionEvent:
     should_not_retry: bool


-STOP_LOOP_TYPE = Literal["STOP_LOOP"]
-STOP_LOOP: STOP_LOOP_TYPE = "STOP_LOOP"  # Sentinel object to stop the loop
-
 BLOCKED_THREAD_WARNING = "THE TIME TO START THE TASK RUN IS TOO LONG, THE EVENT LOOP MAY BE BLOCKED. See https://docs.hatchet.run/blog/warning-event-loop-blocked for details and debugging help."


@@ -56,9 +58,9 @@ class WorkerActionListenerProcess:
         config: ClientConfig,
         action_queue: "Queue[Action]",
         event_queue: "Queue[ActionEvent | STOP_LOOP_TYPE]",
-        handle_kill: bool
-        debug: bool
-        labels: dict[str, str | int]
+        handle_kill: bool,
+        debug: bool,
+        labels: dict[str, str | int],
     ) -> None:
         self.name = name
         self.actions = actions
@@ -75,6 +77,14 @@ class WorkerActionListenerProcess:
         self.action_loop_task: asyncio.Task[None] | None = None
         self.event_send_loop_task: asyncio.Task[None] | None = None
         self.running_step_runs: dict[str, float] = {}
+        self.step_action_events: set[
+            asyncio.Task[UnaryUnaryCall[StepActionEvent, ActionEventResponse] | None]
+        ] = set()
+        self.group_key_action_events: set[
+            asyncio.Task[
+                UnaryUnaryCall[GroupKeyActionEvent, ActionEventResponse] | None
+            ]
+        ] = set()

         if self.debug:
             logger.setLevel(logging.DEBUG)
@@ -144,20 +154,21 @@ class WorkerActionListenerProcess:
                 break

             logger.debug(f"tx: event: {event.action.action_id}/{event.type}")
-            asyncio.create_task(self.send_event(event))
+            t = asyncio.create_task(self.send_event(event))
+            self.step_action_events.add(t)
+            t.add_done_callback(lambda t: self.step_action_events.discard(t))

     async def start_blocked_main_loop(self) -> None:
         threshold = 1
         while not self.killing:
             count = 0
-            for step_run_id, start_time in self.running_step_runs.items():
+            for start_time in self.running_step_runs.values():
                 diff = self.now() - start_time
                 if diff > threshold:
                     count += 1

             if count > 0:
                 logger.warning(f"{BLOCKED_THREAD_WARNING}: Waiting Steps {count}")
-                print(asyncio.current_task())
             await asyncio.sleep(1)

     async def send_event(self, event: ActionEvent, retry_attempt: int = 1) -> None:
@@ -187,7 +198,7 @@ class WorkerActionListenerProcess:
                         self.now()
                     )

-                    asyncio.create_task(
+                    send_started_event_task = asyncio.create_task(
                         self.dispatcher_client.send_step_action_event(
                             event.action,
                             event.type,
@@ -195,14 +206,23 @@ class WorkerActionListenerProcess:
                             event.should_not_retry,
                         )
                     )
+
+                    self.step_action_events.add(send_started_event_task)
+                    send_started_event_task.add_done_callback(
+                        lambda t: self.step_action_events.discard(t)
+                    )
                 case ActionType.CANCEL_STEP_RUN:
                     logger.debug("unimplemented event send")
                 case ActionType.START_GET_GROUP_KEY:
-                    asyncio.create_task(
+                    get_group_key_task = asyncio.create_task(
                         self.dispatcher_client.send_group_key_action_event(
                             event.action, event.type, event.payload
                         )
                     )
+                    self.group_key_action_events.add(get_group_key_task)
+                    get_group_key_task.add_done_callback(
+                        lambda t: self.group_key_action_events.discard(t)
+                    )
                 case _:
                     logger.error("unknown action type for event send")
         except Exception as e:
@@ -317,7 +337,7 @@ def worker_action_listener_process(*args: Any, **kwargs: Any) -> None:
         process = WorkerActionListenerProcess(*args, **kwargs)
         await process.start()
         # Keep the process running
-        while not process.killing:
+        while not process.killing:  # noqa: ASYNC110
            await asyncio.sleep(0.1)

     asyncio.run(run())
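The new step_action_events / group_key_action_events sets address a standard asyncio pitfall: the event loop holds only weak references to tasks, so a bare asyncio.create_task() whose result is dropped can be garbage-collected before it finishes. Holding each task in a set and discarding it from a done-callback keeps a strong reference for exactly the task's lifetime. Standalone illustration of the pattern:

import asyncio

background_tasks: set["asyncio.Task[None]"] = set()

async def send_event(n: int) -> None:
    await asyncio.sleep(0.01)
    print(f"sent event {n}")

async def main() -> None:
    for n in range(3):
        t = asyncio.create_task(send_event(n))
        background_tasks.add(t)  # strong reference until the task completes
        t.add_done_callback(background_tasks.discard)
    # wait for stragglers before exiting
    await asyncio.gather(*background_tasks)

asyncio.run(main())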
hatchet_sdk/worker/runner/run_loop_manager.py
CHANGED

@@ -1,19 +1,17 @@
 import asyncio
 import logging
 from multiprocessing import Queue
-from typing import Any, Literal, TypeVar
+from typing import Any, TypeVar

 from hatchet_sdk.client import Client
 from hatchet_sdk.config import ClientConfig
 from hatchet_sdk.logger import logger
 from hatchet_sdk.runnables.action import Action
 from hatchet_sdk.runnables.task import Task
+from hatchet_sdk.utils.typing import STOP_LOOP, STOP_LOOP_TYPE
 from hatchet_sdk.worker.action_listener_process import ActionEvent
 from hatchet_sdk.worker.runner.runner import Runner
-from hatchet_sdk.worker.runner.utils.capture_logs import capture_logs
-
-STOP_LOOP_TYPE = Literal["STOP_LOOP"]
-STOP_LOOP: STOP_LOOP_TYPE = "STOP_LOOP"
+from hatchet_sdk.worker.runner.utils.capture_logs import AsyncLogSender, capture_logs

 T = TypeVar("T")

@@ -28,10 +26,10 @@ class WorkerActionRunLoopManager:
         action_queue: "Queue[Action | STOP_LOOP_TYPE]",
         event_queue: "Queue[ActionEvent]",
         loop: asyncio.AbstractEventLoop,
-        handle_kill: bool
-        debug: bool
-        labels: dict[str, str | int]
-        lifespan_context: Any | None
+        handle_kill: bool,
+        debug: bool,
+        labels: dict[str, str | int] | None,
+        lifespan_context: Any | None,
     ) -> None:
         self.name = name
         self.action_registry = action_registry
@@ -52,15 +50,19 @@ class WorkerActionRunLoopManager:
         self.runner: Runner | None = None

         self.client = Client(config=self.config, debug=self.debug)
+        self.start_loop_manager_task: asyncio.Task[None] | None = None
+        self.log_sender = AsyncLogSender(self.client.event)
+        self.log_task = self.loop.create_task(self.log_sender.consume())
+
         self.start()

     def start(self) -> None:
-        self.loop.create_task(self.aio_start())
+        self.start_loop_manager_task = self.loop.create_task(self.aio_start())

     async def aio_start(self, retry_count: int = 1) -> None:
         await capture_logs(
             self.client.log_interceptor,
-            self.client.event,
+            self.log_sender,
             self._async_start,
         )()

@@ -75,6 +77,7 @@ class WorkerActionRunLoopManager:
         self.killing = True

         self.action_queue.put(STOP_LOOP)
+        self.log_sender.publish(STOP_LOOP)

     async def wait_for_tasks(self) -> None:
         if self.runner:
@@ -89,6 +92,7 @@ class WorkerActionRunLoopManager:
             self.action_registry,
             self.labels,
             self.lifespan_context,
+            self.log_sender,
         )

         logger.debug(f"'{self.name}' waiting for {list(self.action_registry.keys())}")
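AsyncLogSender's implementation lives in capture_logs.py (+72/-31 in this release) and is not visible here; from the calls above (publish() from producers, a long-lived consume() task, and a STOP_LOOP publish on shutdown) it plausibly wraps a queue like this hypothetical sketch:

import asyncio
from typing import Literal

STOP_LOOP_TYPE = Literal["STOP_LOOP"]
STOP_LOOP: STOP_LOOP_TYPE = "STOP_LOOP"

class LogSenderSketch:
    # hypothetical stand-in for hatchet_sdk's AsyncLogSender
    def __init__(self) -> None:
        self.q: "asyncio.Queue[str | STOP_LOOP_TYPE]" = asyncio.Queue()

    def publish(self, line: "str | STOP_LOOP_TYPE") -> None:
        self.q.put_nowait(line)

    async def consume(self) -> None:
        while True:
            line = await self.q.get()
            if line == STOP_LOOP:  # sentinel drains the loop on shutdown
                return
            print("forward to events client:", line)

async def main() -> None:
    sender = LogSenderSketch()
    consumer = asyncio.create_task(sender.consume())
    sender.publish("step log line")
    sender.publish(STOP_LOOP)
    await consumer

asyncio.run(main())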
hatchet_sdk/worker/runner/runner.py
CHANGED

@@ -1,14 +1,14 @@
 import asyncio
-import contextvars
 import ctypes
 import functools
 import json
-import traceback
+from collections.abc import Callable
 from concurrent.futures import ThreadPoolExecutor
+from contextlib import suppress
 from enum import Enum
 from multiprocessing import Queue
 from threading import Thread, current_thread
-from typing import Any, Callable, Literal, cast, overload
+from typing import Any, Literal, cast, overload

 from pydantic import BaseModel

@@ -30,7 +30,7 @@ from hatchet_sdk.contracts.dispatcher_pb2 import (
     STEP_EVENT_TYPE_FAILED,
     STEP_EVENT_TYPE_STARTED,
 )
-from hatchet_sdk.exceptions import NonRetryableException
+from hatchet_sdk.exceptions import NonRetryableException, TaskRunError
 from hatchet_sdk.features.runs import RunsClient
 from hatchet_sdk.logger import logger
 from hatchet_sdk.runnables.action import Action, ActionKey, ActionType
@@ -40,12 +40,17 @@ from hatchet_sdk.runnables.contextvars import (
     ctx_worker_id,
     ctx_workflow_run_id,
     spawn_index_lock,
+    task_count,
     workflow_spawn_indices,
 )
 from hatchet_sdk.runnables.task import Task
 from hatchet_sdk.runnables.types import R, TWorkflowInput
 from hatchet_sdk.worker.action_listener_process import ActionEvent
-from hatchet_sdk.worker.runner.utils.capture_logs import copy_context_vars
+from hatchet_sdk.worker.runner.utils.capture_logs import (
+    AsyncLogSender,
+    ContextVarToCopy,
+    copy_context_vars,
+)


 class WorkerStatus(Enum):
@@ -61,10 +66,11 @@ class Runner:
         event_queue: "Queue[ActionEvent]",
         config: ClientConfig,
         slots: int,
-        handle_kill: bool
-        action_registry: dict[str, Task[TWorkflowInput, R]]
-        labels: dict[str, str | int]
-        lifespan_context: Any | None
+        handle_kill: bool,
+        action_registry: dict[str, Task[TWorkflowInput, R]],
+        labels: dict[str, str | int] | None,
+        lifespan_context: Any | None,
+        log_sender: AsyncLogSender,
     ):
         # We store the config so we can dynamically create clients for the dispatcher client.
         self.config = config
@@ -72,13 +78,14 @@ class Runner:
         self.slots = slots
         self.tasks: dict[ActionKey, asyncio.Task[Any]] = {}  # Store run ids and futures
         self.contexts: dict[ActionKey, Context] = {}  # Store run ids and contexts
-        self.action_registry = action_registry
+        self.action_registry = action_registry or {}

         self.event_queue = event_queue

         # The thread pool is used for synchronous functions which need to run concurrently
         self.thread_pool = ThreadPoolExecutor(max_workers=slots)
-        self.threads: dict[ActionKey, Thread] = {}  # Store run ids and threads
+        self.threads: dict[ActionKey, Thread] = {}  # Store run ids and threads
+        self.running_tasks = set[asyncio.Task[Exception | None]]()

         self.killing = False
         self.handle_kill = handle_kill
@@ -101,10 +108,11 @@ class Runner:
         self.durable_event_listener = DurableEventListener(self.config)

         self.worker_context = WorkerContext(
-            labels=labels, client=Client(config=config).dispatcher
+            labels=labels or {}, client=Client(config=config).dispatcher
         )

         self.lifespan_context = lifespan_context
+        self.log_sender = log_sender

         if self.config.enable_thread_pool_monitoring:
             self.start_background_monitoring()
@@ -116,67 +124,68 @@ class Runner:
         if self.worker_context.id() is None:
             self.worker_context._worker_id = action.worker_id

+        t: asyncio.Task[Exception | None] | None = None
         match action.action_type:
             case ActionType.START_STEP_RUN:
                 log = f"run: start step: {action.action_id}/{action.step_run_id}"
                 logger.info(log)
-                asyncio.create_task(self.handle_start_step_run(action))
+                t = asyncio.create_task(self.handle_start_step_run(action))
             case ActionType.CANCEL_STEP_RUN:
                 log = f"cancel: step run: {action.action_id}/{action.step_run_id}/{action.retry_count}"
                 logger.info(log)
-                asyncio.create_task(self.handle_cancel_action(action))
+                t = asyncio.create_task(self.handle_cancel_action(action))
             case ActionType.START_GET_GROUP_KEY:
                 log = f"run: get group key: {action.action_id}/{action.get_group_key_run_id}"
                 logger.info(log)
-                asyncio.create_task(self.handle_start_group_key_run(action))
+                t = asyncio.create_task(self.handle_start_group_key_run(action))
             case _:
                 log = f"unknown action type: {action.action_type}"
                 logger.error(log)

+        if t is not None:
+            self.running_tasks.add(t)
+            t.add_done_callback(lambda task: self.running_tasks.discard(task))
+
     def step_run_callback(self, action: Action) -> Callable[[asyncio.Task[Any]], None]:
         def inner_callback(task: asyncio.Task[Any]) -> None:
             self.cleanup_run_id(action.key)

-            errored = False
-            cancelled = task.cancelled()
-            output = None
+            if task.cancelled():
+                return

-            # Get the output from the future
             try:
-                if not cancelled:
-                    output = task.result()
+                output = task.result()
             except Exception as e:
-                errored = True
-
                 should_not_retry = isinstance(e, NonRetryableException)

+                exc = TaskRunError.from_exception(e)
+
                 # This except is coming from the application itself, so we want to send that to the Hatchet instance
                 self.event_queue.put(
                     ActionEvent(
                         action=action,
                         type=STEP_EVENT_TYPE_FAILED,
-                        payload=str(pretty_format_exception(f"{e}", e)),
+                        payload=exc.serialize(),
                         should_not_retry=should_not_retry,
                     )
                 )

                 logger.error(
-                    f"failed step run: {action.action_id}/{action.step_run_id}"
+                    f"failed step run: {action.action_id}/{action.step_run_id}\n{exc.serialize()}"
                 )

-            if not errored and not cancelled:
-                self.event_queue.put(
-                    ActionEvent(
-                        action=action,
-                        type=STEP_EVENT_TYPE_COMPLETED,
-                        payload=self.serialize_output(output),
-                        should_not_retry=False,
-                    )
-                )
+                return

-                logger.info(
-                    f"finished step run: {action.action_id}/{action.step_run_id}"
-                )
+            self.event_queue.put(
+                ActionEvent(
+                    action=action,
+                    type=STEP_EVENT_TYPE_COMPLETED,
+                    payload=self.serialize_output(output),
+                    should_not_retry=False,
+                )
+            )
+
+            logger.info(f"finished step run: {action.action_id}/{action.step_run_id}")

         return inner_callback
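The reshaped callback follows the canonical done-callback sequence: check task.cancelled() first (calling result() on a cancelled task raises CancelledError), then let task.result() re-raise any exception from the task body, and only fall through to the success path. A standalone illustration:

import asyncio

def on_done(task: "asyncio.Task[int]") -> None:
    if task.cancelled():
        return  # nothing to report for cancelled runs
    try:
        output = task.result()  # re-raises the task's exception, if any
    except Exception as e:
        print(f"report FAILED: {e!r}")
        return
    print(f"report COMPLETED: {output}")

async def boom() -> int:
    raise RuntimeError("step error")

async def main() -> None:
    ok = asyncio.create_task(asyncio.sleep(0, result=42))
    bad = asyncio.create_task(boom())
    for t in (ok, bad):
        t.add_done_callback(on_done)
    await asyncio.gather(ok, bad, return_exceptions=True)

asyncio.run(main())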
@@ -186,51 +195,46 @@ class Runner:
         def inner_callback(task: asyncio.Task[Any]) -> None:
             self.cleanup_run_id(action.key)

-            errored = False
-            cancelled = task.cancelled()
-            output = None
+            if task.cancelled():
+                return

-            # Get the output from the future
             try:
-                if not cancelled:
-                    output = task.result()
+                output = task.result()
             except Exception as e:
-                errored = True
+                exc = TaskRunError.from_exception(e)
+
                 self.event_queue.put(
                     ActionEvent(
                         action=action,
                         type=GROUP_KEY_EVENT_TYPE_FAILED,
-                        payload=str(pretty_format_exception(f"{e}", e)),
+                        payload=exc.serialize(),
                         should_not_retry=False,
                     )
                 )

                 logger.error(
-                    f"failed step run: {action.action_id}/{action.step_run_id}"
+                    f"failed step run: {action.action_id}/{action.step_run_id}\n{exc.serialize()}"
                 )

-            if not errored and not cancelled:
-                self.event_queue.put(
-                    ActionEvent(
-                        action=action,
-                        type=GROUP_KEY_EVENT_TYPE_COMPLETED,
-                        payload=self.serialize_output(output),
-                        should_not_retry=False,
-                    )
-                )
+                return

-                logger.info(
-                    f"finished step run: {action.action_id}/{action.step_run_id}"
-                )
+            self.event_queue.put(
+                ActionEvent(
+                    action=action,
+                    type=GROUP_KEY_EVENT_TYPE_COMPLETED,
+                    payload=self.serialize_output(output),
+                    should_not_retry=False,
+                )
+            )
+
+            logger.info(f"finished step run: {action.action_id}/{action.step_run_id}")

         return inner_callback

     def thread_action_func(
         self, ctx: Context, task: Task[TWorkflowInput, R], action: Action
     ) -> R:
-        if action.step_run_id:
-            self.threads[action.key] = current_thread()
-        elif action.get_group_key_run_id:
+        if action.step_run_id or action.get_group_key_run_id:
             self.threads[action.key] = current_thread()

         return task.call(ctx)
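TaskRunError is new in hatchet_sdk/exceptions.py (+99 lines in this release); only from_exception() and serialize() are visible in this diff. A minimal stand-in consistent with those call sites, and with the pretty_format_exception helper it replaces, might capture the type, message, and traceback:

import traceback

class TaskRunErrorSketch:
    # hypothetical stand-in for hatchet_sdk.exceptions.TaskRunError
    def __init__(self, exc_type: str, message: str, trace: str) -> None:
        self.exc_type = exc_type
        self.message = message
        self.trace = trace

    @classmethod
    def from_exception(cls, e: BaseException) -> "TaskRunErrorSketch":
        trace = "".join(traceback.format_exception(type(e), e, e.__traceback__))
        return cls(type(e).__name__, str(e), trace)

    def serialize(self) -> str:
        return f"{self.exc_type}: {self.message}\n{self.trace}"

try:
    raise ValueError("bad input")
except ValueError as e:
    print(TaskRunErrorSketch.from_exception(e).serialize())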
@@ -250,28 +254,36 @@ class Runner:
         try:
             if task.is_async_function:
                 return await task.aio_call(ctx)
-
-            pfunc = functools.partial(
-                # we must copy the context vars to the new thread, as only asyncio natively supports
-                # contextvars
-                copy_context_vars,
-                contextvars.copy_context().items(),
-                self.thread_action_func,
-                ctx,
-                task,
-                action,
+            pfunc = functools.partial(
+                # we must copy the context vars to the new thread, as only asyncio natively supports
+                # contextvars
+                copy_context_vars,
+                [
+                    ContextVarToCopy(
+                        name="ctx_step_run_id",
+                        value=action.step_run_id,
+                    ),
+                    ContextVarToCopy(
+                        name="ctx_workflow_run_id",
+                        value=action.workflow_run_id,
+                    ),
+                    ContextVarToCopy(
+                        name="ctx_worker_id",
+                        value=action.worker_id,
+                    ),
+                    ContextVarToCopy(
+                        name="ctx_action_key",
+                        value=action.key,
+                    ),
+                ],
+                self.thread_action_func,
+                ctx,
+                task,
+                action,
             )
-            return await asyncio.get_event_loop().run_in_executor(self.thread_pool, pfunc)
+
+            loop = asyncio.get_event_loop()
+            return await loop.run_in_executor(self.thread_pool, pfunc)
         finally:
             self.cleanup_run_id(action.key)
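The rewritten block exists because loop.run_in_executor() does not propagate contextvars into the worker thread (unlike asyncio.to_thread(), which copies the caller's context). ContextVarToCopy and copy_context_vars are SDK helpers whose bodies are not in this diff; a hypothetical equivalent of the copy-then-call dance:

import asyncio
import functools
from collections.abc import Callable
from concurrent.futures import ThreadPoolExecutor
from contextvars import ContextVar
from typing import Any

ctx_step_run_id: "ContextVar[str | None]" = ContextVar("ctx_step_run_id", default=None)

def copy_then_call(
    values: "dict[ContextVar[Any], Any]", fn: "Callable[..., Any]", *args: Any
) -> Any:
    # runs inside the executor thread: restore the vars, then call the real function
    for var, value in values.items():
        var.set(value)
    return fn(*args)

def sync_step() -> str:
    return f"step_run_id={ctx_step_run_id.get()}"

async def main() -> None:
    loop = asyncio.get_running_loop()
    pfunc = functools.partial(copy_then_call, {ctx_step_run_id: "abc-123"}, sync_step)
    with ThreadPoolExecutor(max_workers=1) as pool:
        print(await loop.run_in_executor(pool, pfunc))  # step_run_id=abc-123

asyncio.run(main())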
@@ -295,7 +307,7 @@ class Runner:
         while True:
             await self.log_thread_pool_status()

-            for key in self.threads.keys():
+            for key in self.threads:
                 if key not in self.tasks:
                     logger.debug(f"Potential zombie thread found for key {key}")

@@ -350,6 +362,7 @@ class Runner:
             worker=self.worker_context,
             runs_client=self.runs_client,
             lifespan_context=self.lifespan_context,
+            log_sender=self.log_sender,
         )

     ## IMPORTANT: Keep this method's signature in sync with the wrapper in the OTel instrumentor
@@ -361,7 +374,8 @@ class Runner:

         if action_func:
             context = self.create_context(
-                action,
+                action,
+                True if action_func.is_durable else False,  # noqa: SIM210
             )

             self.contexts[action.key] = context
@@ -382,11 +396,12 @@ class Runner:
             task.add_done_callback(self.step_run_callback(action))
             self.tasks[action.key] = task

-            try:
+            task_count.increment()
+
+            ## FIXME: Handle cancelled exceptions and other special exceptions
+            ## that we don't want to suppress here
+            with suppress(Exception):
                 await task
-            except Exception:
-                # do nothing, this should be caught in the callback
-                pass

             ## Once the step run completes, we need to remove the workflow spawn index
             ## so we don't leak memory
@@ -444,7 +459,7 @@ class Runner:
         res = ctypes.pythonapi.PyThreadState_SetAsyncExc(ctypes.c_long(ident), exc)
         if res == 0:
             raise ValueError("Invalid thread ID")
-        elif res != 1:
+        if res != 1:
             logger.error("PyThreadState_SetAsyncExc failed")

         # Call with exception set to 0 is needed to cleanup properly.
@@ -505,8 +520,3 @@ class Runner:
             logger.info(f"waiting for {running} tasks to finish...")
             await asyncio.sleep(1)
             running = len(self.tasks.keys())
-
-
-def pretty_format_exception(message: str, e: Exception) -> str:
-    trace = "".join(traceback.format_exception(type(e), e, e.__traceback__))
-    return f"{message}\n{trace}"