hatchet-sdk 1.12.3__py3-none-any.whl → 1.14.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of hatchet-sdk might be problematic. Click here for more details.
- hatchet_sdk/__init__.py +54 -40
- hatchet_sdk/clients/admin.py +18 -23
- hatchet_sdk/clients/dispatcher/action_listener.py +4 -3
- hatchet_sdk/clients/dispatcher/dispatcher.py +1 -4
- hatchet_sdk/clients/event_ts.py +2 -1
- hatchet_sdk/clients/events.py +16 -12
- hatchet_sdk/clients/listeners/durable_event_listener.py +4 -2
- hatchet_sdk/clients/listeners/pooled_listener.py +2 -2
- hatchet_sdk/clients/listeners/run_event_listener.py +7 -8
- hatchet_sdk/clients/listeners/workflow_listener.py +14 -6
- hatchet_sdk/clients/rest/api_response.py +3 -2
- hatchet_sdk/clients/rest/models/semaphore_slots.py +1 -1
- hatchet_sdk/clients/rest/models/v1_task_summary.py +5 -0
- hatchet_sdk/clients/rest/models/v1_workflow_run_details.py +11 -1
- hatchet_sdk/clients/rest/models/workflow_version.py +5 -0
- hatchet_sdk/clients/rest/tenacity_utils.py +6 -8
- hatchet_sdk/config.py +2 -0
- hatchet_sdk/connection.py +10 -4
- hatchet_sdk/context/context.py +170 -46
- hatchet_sdk/context/worker_context.py +4 -7
- hatchet_sdk/contracts/dispatcher_pb2.py +38 -38
- hatchet_sdk/contracts/dispatcher_pb2.pyi +4 -2
- hatchet_sdk/contracts/events_pb2.py +13 -13
- hatchet_sdk/contracts/events_pb2.pyi +4 -2
- hatchet_sdk/contracts/v1/workflows_pb2.py +1 -1
- hatchet_sdk/contracts/v1/workflows_pb2.pyi +2 -2
- hatchet_sdk/exceptions.py +103 -1
- hatchet_sdk/features/cron.py +2 -2
- hatchet_sdk/features/filters.py +12 -21
- hatchet_sdk/features/runs.py +4 -4
- hatchet_sdk/features/scheduled.py +8 -9
- hatchet_sdk/hatchet.py +65 -64
- hatchet_sdk/opentelemetry/instrumentor.py +20 -20
- hatchet_sdk/runnables/action.py +1 -2
- hatchet_sdk/runnables/contextvars.py +19 -0
- hatchet_sdk/runnables/task.py +37 -29
- hatchet_sdk/runnables/types.py +9 -8
- hatchet_sdk/runnables/workflow.py +57 -42
- hatchet_sdk/utils/proto_enums.py +4 -4
- hatchet_sdk/utils/timedelta_to_expression.py +2 -3
- hatchet_sdk/utils/typing.py +11 -17
- hatchet_sdk/v0/__init__.py +7 -7
- hatchet_sdk/v0/clients/admin.py +7 -7
- hatchet_sdk/v0/clients/dispatcher/action_listener.py +8 -8
- hatchet_sdk/v0/clients/dispatcher/dispatcher.py +9 -9
- hatchet_sdk/v0/clients/events.py +3 -3
- hatchet_sdk/v0/clients/rest/tenacity_utils.py +1 -1
- hatchet_sdk/v0/clients/run_event_listener.py +3 -3
- hatchet_sdk/v0/clients/workflow_listener.py +5 -5
- hatchet_sdk/v0/context/context.py +6 -6
- hatchet_sdk/v0/hatchet.py +4 -4
- hatchet_sdk/v0/opentelemetry/instrumentor.py +1 -1
- hatchet_sdk/v0/rate_limit.py +1 -1
- hatchet_sdk/v0/v2/callable.py +4 -4
- hatchet_sdk/v0/v2/concurrency.py +2 -2
- hatchet_sdk/v0/v2/hatchet.py +3 -3
- hatchet_sdk/v0/worker/action_listener_process.py +6 -6
- hatchet_sdk/v0/worker/runner/run_loop_manager.py +1 -1
- hatchet_sdk/v0/worker/runner/runner.py +10 -10
- hatchet_sdk/v0/worker/runner/utils/capture_logs.py +1 -1
- hatchet_sdk/v0/worker/worker.py +2 -2
- hatchet_sdk/v0/workflow.py +3 -3
- hatchet_sdk/waits.py +6 -5
- hatchet_sdk/worker/action_listener_process.py +33 -13
- hatchet_sdk/worker/runner/run_loop_manager.py +15 -11
- hatchet_sdk/worker/runner/runner.py +142 -80
- hatchet_sdk/worker/runner/utils/capture_logs.py +72 -31
- hatchet_sdk/worker/worker.py +30 -26
- hatchet_sdk/workflow_run.py +4 -2
- {hatchet_sdk-1.12.3.dist-info → hatchet_sdk-1.14.0.dist-info}/METADATA +1 -1
- {hatchet_sdk-1.12.3.dist-info → hatchet_sdk-1.14.0.dist-info}/RECORD +73 -83
- hatchet_sdk/v0/contracts/dispatcher_pb2.py +0 -102
- hatchet_sdk/v0/contracts/dispatcher_pb2.pyi +0 -387
- hatchet_sdk/v0/contracts/dispatcher_pb2_grpc.py +0 -621
- hatchet_sdk/v0/contracts/events_pb2.py +0 -46
- hatchet_sdk/v0/contracts/events_pb2.pyi +0 -87
- hatchet_sdk/v0/contracts/events_pb2_grpc.py +0 -274
- hatchet_sdk/v0/contracts/workflows_pb2.py +0 -80
- hatchet_sdk/v0/contracts/workflows_pb2.pyi +0 -312
- hatchet_sdk/v0/contracts/workflows_pb2_grpc.py +0 -277
- hatchet_sdk/v0/logger.py +0 -13
- {hatchet_sdk-1.12.3.dist-info → hatchet_sdk-1.14.0.dist-info}/WHEEL +0 -0
- {hatchet_sdk-1.12.3.dist-info → hatchet_sdk-1.14.0.dist-info}/entry_points.txt +0 -0
|
@@ -1,14 +1,14 @@
|
|
|
1
1
|
import asyncio
|
|
2
|
-
import contextvars
|
|
3
2
|
import ctypes
|
|
4
3
|
import functools
|
|
5
4
|
import json
|
|
6
|
-
import
|
|
5
|
+
from collections.abc import Callable
|
|
7
6
|
from concurrent.futures import ThreadPoolExecutor
|
|
7
|
+
from contextlib import suppress
|
|
8
8
|
from enum import Enum
|
|
9
9
|
from multiprocessing import Queue
|
|
10
10
|
from threading import Thread, current_thread
|
|
11
|
-
from typing import Any,
|
|
11
|
+
from typing import Any, Literal, cast, overload
|
|
12
12
|
|
|
13
13
|
from pydantic import BaseModel
|
|
14
14
|
|
|
@@ -30,7 +30,11 @@ from hatchet_sdk.contracts.dispatcher_pb2 import (
|
|
|
30
30
|
STEP_EVENT_TYPE_FAILED,
|
|
31
31
|
STEP_EVENT_TYPE_STARTED,
|
|
32
32
|
)
|
|
33
|
-
from hatchet_sdk.exceptions import
|
|
33
|
+
from hatchet_sdk.exceptions import (
|
|
34
|
+
IllegalTaskOutputError,
|
|
35
|
+
NonRetryableException,
|
|
36
|
+
TaskRunError,
|
|
37
|
+
)
|
|
34
38
|
from hatchet_sdk.features.runs import RunsClient
|
|
35
39
|
from hatchet_sdk.logger import logger
|
|
36
40
|
from hatchet_sdk.runnables.action import Action, ActionKey, ActionType
|
|
@@ -40,12 +44,17 @@ from hatchet_sdk.runnables.contextvars import (
|
|
|
40
44
|
ctx_worker_id,
|
|
41
45
|
ctx_workflow_run_id,
|
|
42
46
|
spawn_index_lock,
|
|
47
|
+
task_count,
|
|
43
48
|
workflow_spawn_indices,
|
|
44
49
|
)
|
|
45
50
|
from hatchet_sdk.runnables.task import Task
|
|
46
51
|
from hatchet_sdk.runnables.types import R, TWorkflowInput
|
|
47
52
|
from hatchet_sdk.worker.action_listener_process import ActionEvent
|
|
48
|
-
from hatchet_sdk.worker.runner.utils.capture_logs import
|
|
53
|
+
from hatchet_sdk.worker.runner.utils.capture_logs import (
|
|
54
|
+
AsyncLogSender,
|
|
55
|
+
ContextVarToCopy,
|
|
56
|
+
copy_context_vars,
|
|
57
|
+
)
|
|
49
58
|
|
|
50
59
|
|
|
51
60
|
class WorkerStatus(Enum):
|
|
@@ -61,10 +70,11 @@ class Runner:
|
|
|
61
70
|
event_queue: "Queue[ActionEvent]",
|
|
62
71
|
config: ClientConfig,
|
|
63
72
|
slots: int,
|
|
64
|
-
handle_kill: bool
|
|
65
|
-
action_registry: dict[str, Task[TWorkflowInput, R]]
|
|
66
|
-
labels: dict[str, str | int]
|
|
67
|
-
lifespan_context: Any | None
|
|
73
|
+
handle_kill: bool,
|
|
74
|
+
action_registry: dict[str, Task[TWorkflowInput, R]],
|
|
75
|
+
labels: dict[str, str | int] | None,
|
|
76
|
+
lifespan_context: Any | None,
|
|
77
|
+
log_sender: AsyncLogSender,
|
|
68
78
|
):
|
|
69
79
|
# We store the config so we can dynamically create clients for the dispatcher client.
|
|
70
80
|
self.config = config
|
|
@@ -72,13 +82,14 @@ class Runner:
|
|
|
72
82
|
self.slots = slots
|
|
73
83
|
self.tasks: dict[ActionKey, asyncio.Task[Any]] = {} # Store run ids and futures
|
|
74
84
|
self.contexts: dict[ActionKey, Context] = {} # Store run ids and contexts
|
|
75
|
-
self.action_registry = action_registry
|
|
85
|
+
self.action_registry = action_registry or {}
|
|
76
86
|
|
|
77
87
|
self.event_queue = event_queue
|
|
78
88
|
|
|
79
89
|
# The thread pool is used for synchronous functions which need to run concurrently
|
|
80
90
|
self.thread_pool = ThreadPoolExecutor(max_workers=slots)
|
|
81
|
-
self.threads:
|
|
91
|
+
self.threads: dict[ActionKey, Thread] = {} # Store run ids and threads
|
|
92
|
+
self.running_tasks = set[asyncio.Task[Exception | None]]()
|
|
82
93
|
|
|
83
94
|
self.killing = False
|
|
84
95
|
self.handle_kill = handle_kill
|
|
@@ -101,10 +112,11 @@ class Runner:
|
|
|
101
112
|
self.durable_event_listener = DurableEventListener(self.config)
|
|
102
113
|
|
|
103
114
|
self.worker_context = WorkerContext(
|
|
104
|
-
labels=labels, client=Client(config=config).dispatcher
|
|
115
|
+
labels=labels or {}, client=Client(config=config).dispatcher
|
|
105
116
|
)
|
|
106
117
|
|
|
107
118
|
self.lifespan_context = lifespan_context
|
|
119
|
+
self.log_sender = log_sender
|
|
108
120
|
|
|
109
121
|
if self.config.enable_thread_pool_monitoring:
|
|
110
122
|
self.start_background_monitoring()
|
|
@@ -116,68 +128,90 @@ class Runner:
|
|
|
116
128
|
if self.worker_context.id() is None:
|
|
117
129
|
self.worker_context._worker_id = action.worker_id
|
|
118
130
|
|
|
131
|
+
t: asyncio.Task[Exception | None] | None = None
|
|
119
132
|
match action.action_type:
|
|
120
133
|
case ActionType.START_STEP_RUN:
|
|
121
134
|
log = f"run: start step: {action.action_id}/{action.step_run_id}"
|
|
122
135
|
logger.info(log)
|
|
123
|
-
asyncio.create_task(self.handle_start_step_run(action))
|
|
136
|
+
t = asyncio.create_task(self.handle_start_step_run(action))
|
|
124
137
|
case ActionType.CANCEL_STEP_RUN:
|
|
125
138
|
log = f"cancel: step run: {action.action_id}/{action.step_run_id}/{action.retry_count}"
|
|
126
139
|
logger.info(log)
|
|
127
|
-
asyncio.create_task(self.handle_cancel_action(action))
|
|
140
|
+
t = asyncio.create_task(self.handle_cancel_action(action))
|
|
128
141
|
case ActionType.START_GET_GROUP_KEY:
|
|
129
142
|
log = f"run: get group key: {action.action_id}/{action.get_group_key_run_id}"
|
|
130
143
|
logger.info(log)
|
|
131
|
-
asyncio.create_task(self.handle_start_group_key_run(action))
|
|
144
|
+
t = asyncio.create_task(self.handle_start_group_key_run(action))
|
|
132
145
|
case _:
|
|
133
146
|
log = f"unknown action type: {action.action_type}"
|
|
134
147
|
logger.error(log)
|
|
135
148
|
|
|
149
|
+
if t is not None:
|
|
150
|
+
self.running_tasks.add(t)
|
|
151
|
+
t.add_done_callback(lambda task: self.running_tasks.discard(task))
|
|
152
|
+
|
|
136
153
|
def step_run_callback(self, action: Action) -> Callable[[asyncio.Task[Any]], None]:
|
|
137
154
|
def inner_callback(task: asyncio.Task[Any]) -> None:
|
|
138
155
|
self.cleanup_run_id(action.key)
|
|
139
156
|
|
|
140
|
-
|
|
141
|
-
|
|
142
|
-
output = None
|
|
157
|
+
if task.cancelled():
|
|
158
|
+
return
|
|
143
159
|
|
|
144
|
-
# Get the output from the future
|
|
145
160
|
try:
|
|
146
|
-
|
|
147
|
-
output = task.result()
|
|
161
|
+
output = task.result()
|
|
148
162
|
except Exception as e:
|
|
149
|
-
errored = True
|
|
150
|
-
|
|
151
163
|
should_not_retry = isinstance(e, NonRetryableException)
|
|
152
164
|
|
|
165
|
+
exc = TaskRunError.from_exception(e)
|
|
166
|
+
|
|
153
167
|
# This except is coming from the application itself, so we want to send that to the Hatchet instance
|
|
154
168
|
self.event_queue.put(
|
|
155
169
|
ActionEvent(
|
|
156
170
|
action=action,
|
|
157
171
|
type=STEP_EVENT_TYPE_FAILED,
|
|
158
|
-
payload=
|
|
172
|
+
payload=exc.serialize(),
|
|
159
173
|
should_not_retry=should_not_retry,
|
|
160
174
|
)
|
|
161
175
|
)
|
|
162
176
|
|
|
163
|
-
logger.error
|
|
164
|
-
|
|
177
|
+
log_with_level = logger.info if should_not_retry else logger.error
|
|
178
|
+
|
|
179
|
+
log_with_level(
|
|
180
|
+
f"failed step run: {action.action_id}/{action.step_run_id}\n{exc.serialize()}"
|
|
165
181
|
)
|
|
166
182
|
|
|
167
|
-
|
|
183
|
+
return
|
|
184
|
+
|
|
185
|
+
try:
|
|
186
|
+
output = self.serialize_output(output)
|
|
187
|
+
|
|
168
188
|
self.event_queue.put(
|
|
169
189
|
ActionEvent(
|
|
170
190
|
action=action,
|
|
171
191
|
type=STEP_EVENT_TYPE_COMPLETED,
|
|
172
|
-
payload=
|
|
192
|
+
payload=output,
|
|
193
|
+
should_not_retry=False,
|
|
194
|
+
)
|
|
195
|
+
)
|
|
196
|
+
except IllegalTaskOutputError as e:
|
|
197
|
+
exc = TaskRunError.from_exception(e)
|
|
198
|
+
self.event_queue.put(
|
|
199
|
+
ActionEvent(
|
|
200
|
+
action=action,
|
|
201
|
+
type=STEP_EVENT_TYPE_FAILED,
|
|
202
|
+
payload=exc.serialize(),
|
|
173
203
|
should_not_retry=False,
|
|
174
204
|
)
|
|
175
205
|
)
|
|
176
206
|
|
|
177
|
-
logger.
|
|
178
|
-
f"
|
|
207
|
+
logger.error(
|
|
208
|
+
f"failed step run: {action.action_id}/{action.step_run_id}\n{exc.serialize()}"
|
|
179
209
|
)
|
|
180
210
|
|
|
211
|
+
return
|
|
212
|
+
|
|
213
|
+
logger.info(f"finished step run: {action.action_id}/{action.step_run_id}")
|
|
214
|
+
|
|
181
215
|
return inner_callback
|
|
182
216
|
|
|
183
217
|
def group_key_run_callback(
|
|
@@ -186,51 +220,65 @@ class Runner:
|
|
|
186
220
|
def inner_callback(task: asyncio.Task[Any]) -> None:
|
|
187
221
|
self.cleanup_run_id(action.key)
|
|
188
222
|
|
|
189
|
-
|
|
190
|
-
|
|
191
|
-
output = None
|
|
223
|
+
if task.cancelled():
|
|
224
|
+
return
|
|
192
225
|
|
|
193
|
-
# Get the output from the future
|
|
194
226
|
try:
|
|
195
|
-
|
|
196
|
-
output = task.result()
|
|
227
|
+
output = task.result()
|
|
197
228
|
except Exception as e:
|
|
198
|
-
|
|
229
|
+
exc = TaskRunError.from_exception(e)
|
|
230
|
+
|
|
199
231
|
self.event_queue.put(
|
|
200
232
|
ActionEvent(
|
|
201
233
|
action=action,
|
|
202
234
|
type=GROUP_KEY_EVENT_TYPE_FAILED,
|
|
203
|
-
payload=
|
|
235
|
+
payload=exc.serialize(),
|
|
204
236
|
should_not_retry=False,
|
|
205
237
|
)
|
|
206
238
|
)
|
|
207
239
|
|
|
208
240
|
logger.error(
|
|
209
|
-
f"failed step run: {action.action_id}/{action.step_run_id}"
|
|
241
|
+
f"failed step run: {action.action_id}/{action.step_run_id}\n{exc.serialize()}"
|
|
210
242
|
)
|
|
211
243
|
|
|
212
|
-
|
|
244
|
+
return
|
|
245
|
+
|
|
246
|
+
try:
|
|
247
|
+
output = self.serialize_output(output)
|
|
248
|
+
|
|
213
249
|
self.event_queue.put(
|
|
214
250
|
ActionEvent(
|
|
215
251
|
action=action,
|
|
216
252
|
type=GROUP_KEY_EVENT_TYPE_COMPLETED,
|
|
217
|
-
payload=
|
|
253
|
+
payload=output,
|
|
254
|
+
should_not_retry=False,
|
|
255
|
+
)
|
|
256
|
+
)
|
|
257
|
+
except IllegalTaskOutputError as e:
|
|
258
|
+
exc = TaskRunError.from_exception(e)
|
|
259
|
+
self.event_queue.put(
|
|
260
|
+
ActionEvent(
|
|
261
|
+
action=action,
|
|
262
|
+
type=STEP_EVENT_TYPE_FAILED,
|
|
263
|
+
payload=exc.serialize(),
|
|
218
264
|
should_not_retry=False,
|
|
219
265
|
)
|
|
220
266
|
)
|
|
221
267
|
|
|
222
|
-
logger.
|
|
223
|
-
f"
|
|
268
|
+
logger.error(
|
|
269
|
+
f"failed step run: {action.action_id}/{action.step_run_id}\n{exc.serialize()}"
|
|
224
270
|
)
|
|
225
271
|
|
|
272
|
+
return
|
|
273
|
+
|
|
274
|
+
logger.info(f"finished step run: {action.action_id}/{action.step_run_id}")
|
|
275
|
+
|
|
226
276
|
return inner_callback
|
|
227
277
|
|
|
228
278
|
def thread_action_func(
|
|
229
279
|
self, ctx: Context, task: Task[TWorkflowInput, R], action: Action
|
|
230
280
|
) -> R:
|
|
231
|
-
if action.step_run_id:
|
|
232
|
-
self.threads[action.key] = current_thread()
|
|
233
|
-
elif action.get_group_key_run_id:
|
|
281
|
+
if action.step_run_id or action.get_group_key_run_id:
|
|
234
282
|
self.threads[action.key] = current_thread()
|
|
235
283
|
|
|
236
284
|
return task.call(ctx)
|
|
@@ -250,28 +298,36 @@ class Runner:
|
|
|
250
298
|
try:
|
|
251
299
|
if task.is_async_function:
|
|
252
300
|
return await task.aio_call(ctx)
|
|
253
|
-
|
|
254
|
-
|
|
255
|
-
|
|
256
|
-
|
|
257
|
-
|
|
258
|
-
|
|
259
|
-
|
|
260
|
-
|
|
261
|
-
|
|
262
|
-
|
|
263
|
-
|
|
264
|
-
|
|
265
|
-
|
|
266
|
-
|
|
267
|
-
|
|
268
|
-
|
|
269
|
-
|
|
270
|
-
|
|
271
|
-
|
|
272
|
-
|
|
301
|
+
pfunc = functools.partial(
|
|
302
|
+
# we must copy the context vars to the new thread, as only asyncio natively supports
|
|
303
|
+
# contextvars
|
|
304
|
+
copy_context_vars,
|
|
305
|
+
[
|
|
306
|
+
ContextVarToCopy(
|
|
307
|
+
name="ctx_step_run_id",
|
|
308
|
+
value=action.step_run_id,
|
|
309
|
+
),
|
|
310
|
+
ContextVarToCopy(
|
|
311
|
+
name="ctx_workflow_run_id",
|
|
312
|
+
value=action.workflow_run_id,
|
|
313
|
+
),
|
|
314
|
+
ContextVarToCopy(
|
|
315
|
+
name="ctx_worker_id",
|
|
316
|
+
value=action.worker_id,
|
|
317
|
+
),
|
|
318
|
+
ContextVarToCopy(
|
|
319
|
+
name="ctx_action_key",
|
|
320
|
+
value=action.key,
|
|
321
|
+
),
|
|
322
|
+
],
|
|
323
|
+
self.thread_action_func,
|
|
324
|
+
ctx,
|
|
325
|
+
task,
|
|
326
|
+
action,
|
|
273
327
|
)
|
|
274
|
-
|
|
328
|
+
|
|
329
|
+
loop = asyncio.get_event_loop()
|
|
330
|
+
return await loop.run_in_executor(self.thread_pool, pfunc)
|
|
275
331
|
finally:
|
|
276
332
|
self.cleanup_run_id(action.key)
|
|
277
333
|
|
|
@@ -295,7 +351,7 @@ class Runner:
|
|
|
295
351
|
while True:
|
|
296
352
|
await self.log_thread_pool_status()
|
|
297
353
|
|
|
298
|
-
for key in self.threads
|
|
354
|
+
for key in self.threads:
|
|
299
355
|
if key not in self.tasks:
|
|
300
356
|
logger.debug(f"Potential zombie thread found for key {key}")
|
|
301
357
|
|
|
@@ -350,6 +406,7 @@ class Runner:
|
|
|
350
406
|
worker=self.worker_context,
|
|
351
407
|
runs_client=self.runs_client,
|
|
352
408
|
lifespan_context=self.lifespan_context,
|
|
409
|
+
log_sender=self.log_sender,
|
|
353
410
|
)
|
|
354
411
|
|
|
355
412
|
## IMPORTANT: Keep this method's signature in sync with the wrapper in the OTel instrumentor
|
|
@@ -361,7 +418,8 @@ class Runner:
|
|
|
361
418
|
|
|
362
419
|
if action_func:
|
|
363
420
|
context = self.create_context(
|
|
364
|
-
action,
|
|
421
|
+
action,
|
|
422
|
+
True if action_func.is_durable else False, # noqa: SIM210
|
|
365
423
|
)
|
|
366
424
|
|
|
367
425
|
self.contexts[action.key] = context
|
|
@@ -382,11 +440,12 @@ class Runner:
|
|
|
382
440
|
task.add_done_callback(self.step_run_callback(action))
|
|
383
441
|
self.tasks[action.key] = task
|
|
384
442
|
|
|
385
|
-
|
|
443
|
+
task_count.increment()
|
|
444
|
+
|
|
445
|
+
## FIXME: Handle cancelled exceptions and other special exceptions
|
|
446
|
+
## that we don't want to suppress here
|
|
447
|
+
with suppress(Exception):
|
|
386
448
|
await task
|
|
387
|
-
except Exception:
|
|
388
|
-
# do nothing, this should be caught in the callback
|
|
389
|
-
pass
|
|
390
449
|
|
|
391
450
|
## Once the step run completes, we need to remove the workflow spawn index
|
|
392
451
|
## so we don't leak memory
|
|
@@ -444,7 +503,7 @@ class Runner:
|
|
|
444
503
|
res = ctypes.pythonapi.PyThreadState_SetAsyncExc(ctypes.c_long(ident), exc)
|
|
445
504
|
if res == 0:
|
|
446
505
|
raise ValueError("Invalid thread ID")
|
|
447
|
-
|
|
506
|
+
if res != 1:
|
|
448
507
|
logger.error("PyThreadState_SetAsyncExc failed")
|
|
449
508
|
|
|
450
509
|
# Call with exception set to 0 is needed to cleanup properly.
|
|
@@ -487,8 +546,16 @@ class Runner:
|
|
|
487
546
|
self.cleanup_run_id(key)
|
|
488
547
|
|
|
489
548
|
def serialize_output(self, output: Any) -> str:
|
|
549
|
+
if not output:
|
|
550
|
+
return ""
|
|
551
|
+
|
|
490
552
|
if isinstance(output, BaseModel):
|
|
491
|
-
|
|
553
|
+
output = output.model_dump()
|
|
554
|
+
|
|
555
|
+
if not isinstance(output, dict):
|
|
556
|
+
raise IllegalTaskOutputError(
|
|
557
|
+
f"Tasks must return either a dictionary or a Pydantic BaseModel which can be serialized to a JSON object. Got object of type {type(output)} instead."
|
|
558
|
+
)
|
|
492
559
|
|
|
493
560
|
if output is not None:
|
|
494
561
|
try:
|
|
@@ -505,8 +572,3 @@ class Runner:
|
|
|
505
572
|
logger.info(f"waiting for {running} tasks to finish...")
|
|
506
573
|
await asyncio.sleep(1)
|
|
507
574
|
running = len(self.tasks.keys())
|
|
508
|
-
|
|
509
|
-
|
|
510
|
-
def pretty_format_exception(message: str, e: Exception) -> str:
|
|
511
|
-
trace = "".join(traceback.format_exception(type(e), e, e.__traceback__))
|
|
512
|
-
return f"{message}\n{trace}"
|
|
@@ -1,73 +1,114 @@
|
|
|
1
|
+
import asyncio
|
|
1
2
|
import functools
|
|
2
3
|
import logging
|
|
3
|
-
from
|
|
4
|
-
from contextvars import ContextVar
|
|
4
|
+
from collections.abc import Awaitable, Callable
|
|
5
5
|
from io import StringIO
|
|
6
|
-
from typing import
|
|
6
|
+
from typing import Literal, ParamSpec, TypeVar
|
|
7
|
+
|
|
8
|
+
from pydantic import BaseModel
|
|
7
9
|
|
|
8
10
|
from hatchet_sdk.clients.events import EventClient
|
|
9
11
|
from hatchet_sdk.logger import logger
|
|
10
|
-
from hatchet_sdk.runnables.contextvars import
|
|
12
|
+
from hatchet_sdk.runnables.contextvars import (
|
|
13
|
+
ctx_action_key,
|
|
14
|
+
ctx_step_run_id,
|
|
15
|
+
ctx_worker_id,
|
|
16
|
+
ctx_workflow_run_id,
|
|
17
|
+
)
|
|
18
|
+
from hatchet_sdk.utils.typing import STOP_LOOP, STOP_LOOP_TYPE
|
|
11
19
|
|
|
12
20
|
T = TypeVar("T")
|
|
13
21
|
P = ParamSpec("P")
|
|
14
22
|
|
|
15
23
|
|
|
24
|
+
class ContextVarToCopy(BaseModel):
|
|
25
|
+
name: Literal[
|
|
26
|
+
"ctx_workflow_run_id", "ctx_step_run_id", "ctx_action_key", "ctx_worker_id"
|
|
27
|
+
]
|
|
28
|
+
value: str | None
|
|
29
|
+
|
|
30
|
+
|
|
16
31
|
def copy_context_vars(
|
|
17
|
-
ctx_vars:
|
|
32
|
+
ctx_vars: list[ContextVarToCopy],
|
|
18
33
|
func: Callable[P, T],
|
|
19
34
|
*args: P.args,
|
|
20
35
|
**kwargs: P.kwargs,
|
|
21
36
|
) -> T:
|
|
22
|
-
for var
|
|
23
|
-
var.
|
|
37
|
+
for var in ctx_vars:
|
|
38
|
+
if var.name == "ctx_workflow_run_id":
|
|
39
|
+
ctx_workflow_run_id.set(var.value)
|
|
40
|
+
elif var.name == "ctx_step_run_id":
|
|
41
|
+
ctx_step_run_id.set(var.value)
|
|
42
|
+
elif var.name == "ctx_action_key":
|
|
43
|
+
ctx_action_key.set(var.value)
|
|
44
|
+
elif var.name == "ctx_worker_id":
|
|
45
|
+
ctx_worker_id.set(var.value)
|
|
46
|
+
else:
|
|
47
|
+
raise ValueError(f"Unknown context variable name: {var.name}")
|
|
48
|
+
|
|
24
49
|
return func(*args, **kwargs)
|
|
25
50
|
|
|
26
51
|
|
|
27
|
-
class
|
|
28
|
-
|
|
29
|
-
|
|
30
|
-
def filter(self, record: logging.LogRecord) -> bool:
|
|
31
|
-
## TODO: Change how we do this to not assign to the log record
|
|
32
|
-
record.workflow_run_id = ctx_workflow_run_id.get()
|
|
33
|
-
record.step_run_id = ctx_step_run_id.get()
|
|
34
|
-
return True
|
|
52
|
+
class LogRecord(BaseModel):
|
|
53
|
+
message: str
|
|
54
|
+
step_run_id: str
|
|
35
55
|
|
|
36
56
|
|
|
37
|
-
class
|
|
38
|
-
def __init__(self, event_client: EventClient
|
|
39
|
-
super().__init__(stream)
|
|
40
|
-
self.logger_thread_pool = ThreadPoolExecutor(max_workers=1)
|
|
57
|
+
class AsyncLogSender:
|
|
58
|
+
def __init__(self, event_client: EventClient):
|
|
41
59
|
self.event_client = event_client
|
|
60
|
+
self.q = asyncio.Queue[LogRecord | STOP_LOOP_TYPE](maxsize=1000)
|
|
61
|
+
|
|
62
|
+
async def consume(self) -> None:
|
|
63
|
+
while True:
|
|
64
|
+
record = await self.q.get()
|
|
65
|
+
|
|
66
|
+
if record == STOP_LOOP:
|
|
67
|
+
break
|
|
42
68
|
|
|
43
|
-
|
|
69
|
+
try:
|
|
70
|
+
self.event_client.log(
|
|
71
|
+
message=record.message, step_run_id=record.step_run_id
|
|
72
|
+
)
|
|
73
|
+
except Exception as e:
|
|
74
|
+
logger.error(f"Error logging: {e}")
|
|
75
|
+
|
|
76
|
+
def publish(self, record: LogRecord | STOP_LOOP_TYPE) -> None:
|
|
44
77
|
try:
|
|
45
|
-
|
|
46
|
-
|
|
78
|
+
self.q.put_nowait(record)
|
|
79
|
+
except asyncio.QueueFull:
|
|
80
|
+
logger.warning("Log queue is full, dropping log message")
|
|
81
|
+
|
|
47
82
|
|
|
48
|
-
|
|
49
|
-
|
|
50
|
-
|
|
83
|
+
class CustomLogHandler(logging.StreamHandler): # type: ignore[type-arg]
|
|
84
|
+
def __init__(self, log_sender: AsyncLogSender, stream: StringIO):
|
|
85
|
+
super().__init__(stream)
|
|
86
|
+
|
|
87
|
+
self.log_sender = log_sender
|
|
51
88
|
|
|
52
89
|
def emit(self, record: logging.LogRecord) -> None:
|
|
53
90
|
super().emit(record)
|
|
54
91
|
|
|
55
92
|
log_entry = self.format(record)
|
|
93
|
+
step_run_id = ctx_step_run_id.get()
|
|
94
|
+
|
|
95
|
+
if not step_run_id:
|
|
96
|
+
return
|
|
56
97
|
|
|
57
|
-
|
|
58
|
-
self.logger_thread_pool.submit(self._log, log_entry, record.step_run_id) # type: ignore
|
|
98
|
+
self.log_sender.publish(LogRecord(message=log_entry, step_run_id=step_run_id))
|
|
59
99
|
|
|
60
100
|
|
|
61
101
|
def capture_logs(
|
|
62
|
-
logger: logging.Logger,
|
|
102
|
+
logger: logging.Logger, log_sender: AsyncLogSender, func: Callable[P, Awaitable[T]]
|
|
63
103
|
) -> Callable[P, Awaitable[T]]:
|
|
64
104
|
@functools.wraps(func)
|
|
65
105
|
async def wrapper(*args: P.args, **kwargs: P.kwargs) -> T:
|
|
66
106
|
log_stream = StringIO()
|
|
67
|
-
custom_handler = CustomLogHandler(
|
|
107
|
+
custom_handler = CustomLogHandler(log_sender, log_stream)
|
|
68
108
|
custom_handler.setLevel(logging.INFO)
|
|
69
|
-
|
|
70
|
-
logger.
|
|
109
|
+
|
|
110
|
+
if not any(h for h in logger.handlers if isinstance(h, CustomLogHandler)):
|
|
111
|
+
logger.addHandler(custom_handler)
|
|
71
112
|
|
|
72
113
|
try:
|
|
73
114
|
result = await func(*args, **kwargs)
|