hatchet-sdk 1.15.2__py3-none-any.whl → 1.16.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of hatchet-sdk might be problematic.
- hatchet_sdk/clients/admin.py +3 -1
- hatchet_sdk/clients/dispatcher/action_listener.py +13 -13
- hatchet_sdk/clients/event_ts.py +1 -1
- hatchet_sdk/clients/listeners/pooled_listener.py +4 -4
- hatchet_sdk/clients/rest/tenacity_utils.py +1 -1
- hatchet_sdk/clients/v1/api_client.py +3 -3
- hatchet_sdk/context/context.py +6 -6
- hatchet_sdk/features/runs.py +21 -1
- hatchet_sdk/opentelemetry/instrumentor.py +3 -3
- hatchet_sdk/runnables/contextvars.py +4 -0
- hatchet_sdk/runnables/task.py +133 -0
- hatchet_sdk/runnables/workflow.py +218 -11
- hatchet_sdk/utils/serde.py +8 -10
- hatchet_sdk/worker/action_listener_process.py +11 -11
- hatchet_sdk/worker/runner/runner.py +39 -21
- hatchet_sdk/worker/runner/utils/capture_logs.py +30 -15
- hatchet_sdk/worker/worker.py +11 -14
- {hatchet_sdk-1.15.2.dist-info → hatchet_sdk-1.16.0.dist-info}/METADATA +1 -1
- {hatchet_sdk-1.15.2.dist-info → hatchet_sdk-1.16.0.dist-info}/RECORD +21 -21
- {hatchet_sdk-1.15.2.dist-info → hatchet_sdk-1.16.0.dist-info}/WHEEL +0 -0
- {hatchet_sdk-1.15.2.dist-info → hatchet_sdk-1.16.0.dist-info}/entry_points.txt +0 -0
hatchet_sdk/runnables/workflow.py
CHANGED

@@ -2,7 +2,16 @@ import asyncio
 from collections.abc import Callable
 from datetime import datetime, timedelta
 from functools import cached_property
-from typing import
+from typing import (
+    TYPE_CHECKING,
+    Any,
+    Generic,
+    Literal,
+    TypeVar,
+    cast,
+    get_type_hints,
+    overload,
+)
 
 from google.protobuf import timestamp_pb2
 from pydantic import BaseModel, model_validator
@@ -651,39 +660,83 @@ class Workflow(BaseWorkflow[TWorkflowInput]):
 
         return await ref.aio_result()
 
+    def _get_result(
+        self, ref: WorkflowRunRef, return_exceptions: bool
+    ) -> dict[str, Any] | BaseException:
+        try:
+            return ref.result()
+        except Exception as e:
+            if return_exceptions:
+                return e
+            raise e
+
+    @overload
+    def run_many(
+        self,
+        workflows: list[WorkflowRunTriggerConfig],
+        return_exceptions: Literal[True],
+    ) -> list[dict[str, Any] | BaseException]: ...
+
+    @overload
     def run_many(
         self,
         workflows: list[WorkflowRunTriggerConfig],
-
+        return_exceptions: Literal[False] = False,
+    ) -> list[dict[str, Any]]: ...
+
+    def run_many(
+        self,
+        workflows: list[WorkflowRunTriggerConfig],
+        return_exceptions: bool = False,
+    ) -> list[dict[str, Any]] | list[dict[str, Any] | BaseException]:
         """
         Run a workflow in bulk and wait for all runs to complete.
         This method triggers multiple workflow runs, blocks until all of them complete, and returns the final results.
 
         :param workflows: A list of `WorkflowRunTriggerConfig` objects, each representing a workflow run to be triggered.
+        :param return_exceptions: If `True`, exceptions will be returned as part of the results instead of raising them.
         :returns: A list of results for each workflow run.
         """
         refs = self.client._client.admin.run_workflows(
             workflows=workflows,
         )
 
-        return [
+        return [self._get_result(ref, return_exceptions) for ref in refs]
 
+    @overload
     async def aio_run_many(
         self,
         workflows: list[WorkflowRunTriggerConfig],
-
+        return_exceptions: Literal[True],
+    ) -> list[dict[str, Any] | BaseException]: ...
+
+    @overload
+    async def aio_run_many(
+        self,
+        workflows: list[WorkflowRunTriggerConfig],
+        return_exceptions: Literal[False] = False,
+    ) -> list[dict[str, Any]]: ...
+
+    async def aio_run_many(
+        self,
+        workflows: list[WorkflowRunTriggerConfig],
+        return_exceptions: bool = False,
+    ) -> list[dict[str, Any]] | list[dict[str, Any] | BaseException]:
         """
         Run a workflow in bulk and wait for all runs to complete.
         This method triggers multiple workflow runs, blocks until all of them complete, and returns the final results.
 
         :param workflows: A list of `WorkflowRunTriggerConfig` objects, each representing a workflow run to be triggered.
+        :param return_exceptions: If `True`, exceptions will be returned as part of the results instead of raising them.
         :returns: A list of results for each workflow run.
         """
         refs = await self.client._client.admin.aio_run_workflows(
             workflows=workflows,
         )
 
-        return await asyncio.gather(
+        return await asyncio.gather(
+            *[ref.aio_result() for ref in refs], return_exceptions=return_exceptions
+        )
 
     def run_many_no_wait(
         self,
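
The overloads above add a `return_exceptions` flag to `run_many` / `aio_run_many`, mirroring `asyncio.gather`: failed runs come back as exception objects instead of aborting the whole batch. A minimal sketch of how the flag could be used; the `hatchet.workflow` / `create_bulk_run_item` setup below is an assumption based on the existing SDK surface and is not part of this diff:

from pydantic import BaseModel

from hatchet_sdk import Context, Hatchet

hatchet = Hatchet()


class MyInput(BaseModel):
    value: int


# Hypothetical workflow, used only to illustrate the new flag.
bulk_workflow = hatchet.workflow(name="bulk-example", input_validator=MyInput)


@bulk_workflow.task()
def double(input: MyInput, ctx: Context) -> dict[str, int]:
    return {"doubled": input.value * 2}


configs = [
    bulk_workflow.create_bulk_run_item(input=MyInput(value=i)) for i in range(10)
]

# New in 1.16.0: failed runs are returned inline instead of raising out of run_many.
results = bulk_workflow.run_many(configs, return_exceptions=True)

for result in results:
    if isinstance(result, BaseException):
        print("run failed:", result)
    else:
        print("run succeeded:", result)

With `return_exceptions` omitted (or `False`), the overloads preserve the old behavior and type: the first failing run raises and the return type stays `list[dict[str, Any]]`.
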
@@ -946,7 +999,7 @@ class Workflow(BaseWorkflow[TWorkflowInput]):
 
         :param backoff_max_seconds: The maximum number of seconds to allow retries with exponential backoff to continue.
 
-        :param concurrency: A list of concurrency expressions for the on-
+        :param concurrency: A list of concurrency expressions for the on-failure task.
 
         :returns: A decorator which creates a `Task` object.
         """
@@ -1137,7 +1190,18 @@ class Standalone(BaseWorkflow[TWorkflowInput], Generic[TWorkflowInput, R]):
 
         self.config = self._workflow.config
 
-
+    @overload
+    def _extract_result(self, result: dict[str, Any]) -> R: ...
+
+    @overload
+    def _extract_result(self, result: BaseException) -> BaseException: ...
+
+    def _extract_result(
+        self, result: dict[str, Any] | BaseException
+    ) -> R | BaseException:
+        if isinstance(result, BaseException):
+            return result
+
         output = result.get(self._task.name)
 
         if not self._output_validator:
@@ -1217,30 +1281,72 @@ class Standalone(BaseWorkflow[TWorkflowInput], Generic[TWorkflowInput, R]):
 
         return TaskRunRef[TWorkflowInput, R](self, ref)
 
-
+    @overload
+    def run_many(
+        self,
+        workflows: list[WorkflowRunTriggerConfig],
+        return_exceptions: Literal[True],
+    ) -> list[R | BaseException]: ...
+
+    @overload
+    def run_many(
+        self,
+        workflows: list[WorkflowRunTriggerConfig],
+        return_exceptions: Literal[False] = False,
+    ) -> list[R]: ...
+
+    def run_many(
+        self, workflows: list[WorkflowRunTriggerConfig], return_exceptions: bool = False
+    ) -> list[R] | list[R | BaseException]:
         """
         Run a workflow in bulk and wait for all runs to complete.
         This method triggers multiple workflow runs, blocks until all of them complete, and returns the final results.
 
         :param workflows: A list of `WorkflowRunTriggerConfig` objects, each representing a workflow run to be triggered.
+        :param return_exceptions: If `True`, exceptions will be returned as part of the results instead of raising them.
         :returns: A list of results for each workflow run.
         """
         return [
             self._extract_result(result)
-            for result in self._workflow.run_many(
+            for result in self._workflow.run_many(
+                workflows,
+                ## hack: typing needs literal
+                True if return_exceptions else False,  # noqa: SIM210
+            )
         ]
 
-
+    @overload
+    async def aio_run_many(
+        self,
+        workflows: list[WorkflowRunTriggerConfig],
+        return_exceptions: Literal[True],
+    ) -> list[R | BaseException]: ...
+
+    @overload
+    async def aio_run_many(
+        self,
+        workflows: list[WorkflowRunTriggerConfig],
+        return_exceptions: Literal[False] = False,
+    ) -> list[R]: ...
+
+    async def aio_run_many(
+        self, workflows: list[WorkflowRunTriggerConfig], return_exceptions: bool = False
+    ) -> list[R] | list[R | BaseException]:
         """
         Run a workflow in bulk and wait for all runs to complete.
         This method triggers multiple workflow runs, blocks until all of them complete, and returns the final results.
 
         :param workflows: A list of `WorkflowRunTriggerConfig` objects, each representing a workflow run to be triggered.
+        :param return_exceptions: If `True`, exceptions will be returned as part of the results instead of raising them.
         :returns: A list of results for each workflow run.
         """
         return [
             self._extract_result(result)
-            for result in await self._workflow.aio_run_many(
+            for result in await self._workflow.aio_run_many(
+                workflows,
+                ## hack: typing needs literal
+                True if return_exceptions else False,  # noqa: SIM210
+            )
         ]
 
     def run_many_no_wait(
@@ -1273,3 +1379,104 @@ class Standalone(BaseWorkflow[TWorkflowInput], Generic[TWorkflowInput, R]):
         refs = await self._workflow.aio_run_many_no_wait(workflows)
 
         return [TaskRunRef[TWorkflowInput, R](self, ref) for ref in refs]
+
+    def mock_run(
+        self,
+        input: TWorkflowInput | None = None,
+        additional_metadata: JSONSerializableMapping | None = None,
+        parent_outputs: dict[str, JSONSerializableMapping] | None = None,
+        retry_count: int = 0,
+        lifespan: Any = None,
+    ) -> R:
+        """
+        Mimic the execution of a task. This method is intended to be used to unit test
+        tasks without needing to interact with the Hatchet engine. Use `mock_run` for sync
+        tasks and `aio_mock_run` for async tasks.
+
+        :param input: The input to the task.
+        :param additional_metadata: Additional metadata to attach to the task.
+        :param parent_outputs: Outputs from parent tasks, if any. This is useful for mimicking DAG functionality. For instance, if you have a task `step_2` that has a `parent` which is `step_1`, you can pass `parent_outputs={"step_1": {"result": "Hello, world!"}}` to `step_2.mock_run()` to be able to access `ctx.task_output(step_1)` in `step_2`.
+        :param retry_count: The number of times the task has been retried.
+        :param lifespan: The lifespan to be used in the task, which is useful if one was set on the worker. This will allow you to access `ctx.lifespan` inside of your task.
+
+        :return: The output of the task.
+        """
+
+        return self._task.mock_run(
+            input=input,
+            additional_metadata=additional_metadata,
+            parent_outputs=parent_outputs,
+            retry_count=retry_count,
+            lifespan=lifespan,
+        )
+
+    async def aio_mock_run(
+        self,
+        input: TWorkflowInput | None = None,
+        additional_metadata: JSONSerializableMapping | None = None,
+        parent_outputs: dict[str, JSONSerializableMapping] | None = None,
+        retry_count: int = 0,
+        lifespan: Any = None,
+    ) -> R:
+        """
+        Mimic the execution of a task. This method is intended to be used to unit test
+        tasks without needing to interact with the Hatchet engine. Use `mock_run` for sync
+        tasks and `aio_mock_run` for async tasks.
+
+        :param input: The input to the task.
+        :param additional_metadata: Additional metadata to attach to the task.
+        :param parent_outputs: Outputs from parent tasks, if any. This is useful for mimicking DAG functionality. For instance, if you have a task `step_2` that has a `parent` which is `step_1`, you can pass `parent_outputs={"step_1": {"result": "Hello, world!"}}` to `step_2.mock_run()` to be able to access `ctx.task_output(step_1)` in `step_2`.
+        :param retry_count: The number of times the task has been retried.
+        :param lifespan: The lifespan to be used in the task, which is useful if one was set on the worker. This will allow you to access `ctx.lifespan` inside of your task.
+
+        :return: The output of the task.
+        """
+
+        return await self._task.aio_mock_run(
+            input=input,
+            additional_metadata=additional_metadata,
+            parent_outputs=parent_outputs,
+            retry_count=retry_count,
+            lifespan=lifespan,
+        )
+
+    @property
+    def is_async_function(self) -> bool:
+        """
+        Check if the task is an async function.
+
+        :returns: True if the task is an async function, False otherwise.
+        """
+        return self._task.is_async_function
+
+    def get_run_ref(self, run_id: str) -> TaskRunRef[TWorkflowInput, R]:
+        """
+        Get a reference to a task run by its run ID.
+
+        :param run_id: The ID of the run to get the reference for.
+        :returns: A `TaskRunRef` object representing the reference to the task run.
+        """
+        wrr = self._workflow.client._client.runs.get_run_ref(run_id)
+        return TaskRunRef[TWorkflowInput, R](self, wrr)
+
+    async def aio_get_result(self, run_id: str) -> R:
+        """
+        Get the result of a task run by its run ID.
+
+        :param run_id: The ID of the run to get the result for.
+        :returns: The result of the task run.
+        """
+        run_ref = self.get_run_ref(run_id)
+
+        return await run_ref.aio_result()
+
+    def get_result(self, run_id: str) -> R:
+        """
+        Get the result of a task run by its run ID.
+
+        :param run_id: The ID of the run to get the result for.
+        :returns: The result of the task run.
+        """
+        run_ref = self.get_run_ref(run_id)
+
+        return run_ref.result()
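
The new `mock_run` / `aio_mock_run` helpers on `Standalone` make tasks unit-testable without a running Hatchet engine, and `get_run_ref` / `get_result` / `aio_get_result` fetch results for an existing run ID. A sketch of a test built on `mock_run`; the `@hatchet.task` decorator usage is the standard SDK pattern, and `greet` / `GreetInput` are invented names:

from pydantic import BaseModel

from hatchet_sdk import Context, Hatchet

hatchet = Hatchet()


class GreetInput(BaseModel):
    name: str


@hatchet.task(input_validator=GreetInput)
def greet(input: GreetInput, ctx: Context) -> dict[str, str]:
    return {"message": f"Hello, {input.name}!"}


def test_greet() -> None:
    # mock_run executes the task function directly; no engine connection is made.
    result = greet.mock_run(input=GreetInput(name="world"))
    assert result == {"message": "Hello, world!"}

For async task functions, the same test would await `greet.aio_mock_run(...)` instead.
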
hatchet_sdk/utils/serde.py
CHANGED

@@ -2,10 +2,7 @@ from typing import Any, TypeVar, cast, overload
 
 T = TypeVar("T")
 K = TypeVar("K")
-
-
-@overload
-def remove_null_unicode_character(data: str, replacement: str = "") -> str: ...
+R = TypeVar("R")
 
 
 @overload
@@ -24,9 +21,13 @@ def remove_null_unicode_character(
 ) -> tuple[T, ...]: ...
 
 
+@overload
+def remove_null_unicode_character(data: R, replacement: str = "") -> R: ...
+
+
 def remove_null_unicode_character(
-    data:
-) -> str | dict[K, T] | list[T] | tuple[T, ...]:
+    data: dict[K, T] | list[T] | tuple[T, ...] | R, replacement: str = ""
+) -> str | dict[K, T] | list[T] | tuple[T, ...] | R:
     """
     Recursively traverse a dictionary (a task's output) and remove the unicode escape sequence \\u0000 which will cause unexpected behavior in Hatchet.
 
@@ -36,7 +37,6 @@ def remove_null_unicode_character(
     :param replacement: The string to replace \\u0000 with.
 
     :return: The same dictionary with all \\u0000 characters removed from strings, and nested dictionaries/lists processed recursively.
-    :raises TypeError: If the input is not a string, dictionary, list, or tuple.
     """
     if isinstance(data, str):
        return data.replace("\u0000", replacement)
@@ -57,6 +57,4 @@
         remove_null_unicode_character(cast(Any, item), replacement) for item in data
     )
 
-
-        f"Unsupported type {type(data)}. Expected str, dict, list, or tuple."
-    )
+    return data
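
Net effect of the serde change: the dedicated `str` overload is replaced by a generic `R` overload, and inputs that are not strings, dicts, lists, or tuples are now returned unchanged instead of raising `TypeError`. An illustration of the expected behavior (the assertions are illustrative, not taken from the package's tests):

from hatchet_sdk.utils.serde import remove_null_unicode_character

# Strings and nested containers are still cleaned recursively.
assert remove_null_unicode_character("bad\u0000value") == "badvalue"
assert remove_null_unicode_character({"k": ["a\u0000", ("b\u0000",)]}) == {"k": ["a", ("b",)]}

# Non-string, non-container values now pass straight through
# (this previously raised a TypeError).
assert remove_null_unicode_character(42) == 42
assert remove_null_unicode_character(None) is None
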
hatchet_sdk/worker/action_listener_process.py
CHANGED

@@ -132,8 +132,8 @@ class WorkerActionListenerProcess:
             )
 
             logger.debug(f"acquired action listener: {self.listener.worker_id}")
-        except grpc.RpcError
-            logger.
+        except grpc.RpcError:
+            logger.exception("could not start action listener")
             return
 
         # Start both loops as background tasks
@@ -168,7 +168,7 @@ class WorkerActionListenerProcess:
                 count += 1
 
             if count > 0:
-                logger.warning(f"{BLOCKED_THREAD_WARNING}
+                logger.warning(f"{BLOCKED_THREAD_WARNING} Waiting Steps {count}")
             await asyncio.sleep(1)
 
     async def send_event(self, event: ActionEvent, retry_attempt: int = 1) -> None:
@@ -188,7 +188,7 @@ class WorkerActionListenerProcess:
         )
         if diff > 0.1:
             logger.warning(
-                f"{BLOCKED_THREAD_WARNING}
+                f"{BLOCKED_THREAD_WARNING} time to start: {diff}s"
             )
         else:
             logger.debug(f"start time: {diff}")
@@ -225,9 +225,9 @@ class WorkerActionListenerProcess:
                     )
                 case _:
                     logger.error("unknown action type for event send")
-        except Exception
-            logger.
-                f"could not send action event ({retry_attempt}/{ACTION_EVENT_RETRY_COUNT})
+        except Exception:
+            logger.exception(
+                f"could not send action event ({retry_attempt}/{ACTION_EVENT_RETRY_COUNT})"
             )
             if retry_attempt <= ACTION_EVENT_RETRY_COUNT:
                 await exp_backoff_sleep(retry_attempt, 1)
@@ -291,11 +291,11 @@ class WorkerActionListenerProcess:
                 )
                 try:
                     self.action_queue.put(action)
-                except Exception
-                    logger.
+                except Exception:
+                    logger.exception("error putting action")
 
-        except Exception
-            logger.
+        except Exception:
+            logger.exception("error in action loop")
         finally:
             logger.info("action loop closed")
             if not self.killing:
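
Most of the edits in `action_listener_process.py` are of the same shape: error logs inside `except` blocks become `logger.exception(...)`, which attaches the active traceback to the log record. In plain standard-library terms:

import logging

logger = logging.getLogger(__name__)

try:
    raise RuntimeError("boom")
except Exception:
    # logger.error("...") would record only the message;
    # logger.exception("...") records the message plus the current traceback.
    logger.exception("could not send action event")
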
hatchet_sdk/worker/runner/runner.py
CHANGED

@@ -40,6 +40,7 @@ from hatchet_sdk.logger import logger
 from hatchet_sdk.runnables.action import Action, ActionKey, ActionType
 from hatchet_sdk.runnables.contextvars import (
     ctx_action_key,
+    ctx_additional_metadata,
     ctx_step_run_id,
     ctx_worker_id,
     ctx_workflow_run_id,
@@ -54,6 +55,8 @@ from hatchet_sdk.worker.action_listener_process import ActionEvent
 from hatchet_sdk.worker.runner.utils.capture_logs import (
     AsyncLogSender,
     ContextVarToCopy,
+    ContextVarToCopyDict,
+    ContextVarToCopyStr,
     copy_context_vars,
 )
 
@@ -295,6 +298,7 @@ class Runner:
         ctx_workflow_run_id.set(action.workflow_run_id)
         ctx_worker_id.set(action.worker_id)
         ctx_action_key.set(action.key)
+        ctx_additional_metadata.set(action.additional_metadata)
 
         try:
             if task.is_async_function:
@@ -305,20 +309,34 @@ class Runner:
                     copy_context_vars,
                     [
                         ContextVarToCopy(
-
-
+                            var=ContextVarToCopyStr(
+                                name="ctx_step_run_id",
+                                value=action.step_run_id,
+                            )
+                        ),
+                        ContextVarToCopy(
+                            var=ContextVarToCopyStr(
+                                name="ctx_workflow_run_id",
+                                value=action.workflow_run_id,
+                            )
                         ),
                         ContextVarToCopy(
-
-
+                            var=ContextVarToCopyStr(
+                                name="ctx_worker_id",
+                                value=action.worker_id,
+                            )
                         ),
                         ContextVarToCopy(
-
-
+                            var=ContextVarToCopyStr(
+                                name="ctx_action_key",
+                                value=action.key,
+                            )
                         ),
                         ContextVarToCopy(
-
-
+                            var=ContextVarToCopyDict(
+                                name="ctx_additional_metadata",
+                                value=action.additional_metadata,
+                            )
                         ),
                     ],
                     self.thread_action_func,
@@ -344,34 +362,34 @@ class Runner:
             "threads_daemon": sum(1 for t in self.thread_pool._threads if t.daemon),
         }
 
-        logger.warning("
+        logger.warning("thread pool detailed status %s", thread_pool_details)
 
     async def _start_monitoring(self) -> None:
-        logger.debug("
+        logger.debug("thread pool monitoring started")
         try:
             while True:
                 await self.log_thread_pool_status()
 
                 for key in self.threads:
                     if key not in self.tasks:
-                        logger.debug(f"
+                        logger.debug(f"potential zombie thread found for key {key}")
 
                 for key, task in self.tasks.items():
                     if task.done() and key in self.threads:
                         logger.debug(
-                            f"
+                            f"task is done but thread still exists for key {key}"
                         )
 
                 await asyncio.sleep(60)
         except asyncio.CancelledError:
-            logger.warning("
+            logger.warning("thread pool monitoring task cancelled")
         except Exception as e:
-            logger.exception(f"
+            logger.exception(f"error in thread pool monitoring: {e}")
 
     def start_background_monitoring(self) -> None:
         loop = asyncio.get_event_loop()
         self.monitoring_task = loop.create_task(self._start_monitoring())
-        logger.debug("
+        logger.debug("started thread pool monitoring background task")
 
     def cleanup_run_id(self, key: ActionKey) -> None:
         if key in self.tasks:
@@ -503,7 +521,7 @@ class Runner:
 
         ident = cast(int, thread.ident)
 
-        logger.info(f"
+        logger.info(f"forcefully terminating thread {ident}")
 
         exc = ctypes.py_object(SystemExit)
         res = ctypes.pythonapi.PyThreadState_SetAsyncExc(ctypes.c_long(ident), exc)
@@ -516,13 +534,13 @@ class Runner:
                 ctypes.pythonapi.PyThreadState_SetAsyncExc(thread.ident, 0)
                 raise SystemError("PyThreadState_SetAsyncExc failed")
 
-            logger.info(f"
+            logger.info(f"successfully terminated thread {ident}")
 
             # Immediately add a new thread to the thread pool, because we've actually killed a worker
             # in the ThreadPoolExecutor
             self.thread_pool.submit(lambda: None)
         except Exception as e:
-            logger.exception(f"
+            logger.exception(f"failed to terminate thread: {e}")
 
     ## IMPORTANT: Keep this method's signature in sync with the wrapper in the OTel instrumentor
     async def handle_cancel_action(self, action: Action) -> None:
@@ -546,7 +564,7 @@ class Runner:
                 await asyncio.sleep(1)
 
                 logger.warning(
-                    f"
+                    f"thread {self.threads[key].ident} with key {key} is still running after cancellation. This could cause the thread pool to get blocked and prevent new tasks from running."
                 )
         finally:
             self.cleanup_run_id(key)
@@ -568,8 +586,8 @@ class Runner:
 
         try:
             serialized_output = json.dumps(output, default=str)
-        except Exception
-            logger.
+        except Exception:
+            logger.exception("could not serialize output")
             serialized_output = str(output)
 
         if "\\u0000" in serialized_output:
hatchet_sdk/worker/runner/utils/capture_logs.py
CHANGED

@@ -5,29 +5,42 @@ from collections.abc import Awaitable, Callable
 from io import StringIO
 from typing import Literal, ParamSpec, TypeVar
 
-from pydantic import BaseModel
+from pydantic import BaseModel, Field
 
 from hatchet_sdk.clients.events import EventClient
 from hatchet_sdk.logger import logger
 from hatchet_sdk.runnables.contextvars import (
     ctx_action_key,
+    ctx_additional_metadata,
     ctx_step_run_id,
     ctx_worker_id,
     ctx_workflow_run_id,
 )
-from hatchet_sdk.utils.typing import STOP_LOOP, STOP_LOOP_TYPE
+from hatchet_sdk.utils.typing import STOP_LOOP, STOP_LOOP_TYPE, JSONSerializableMapping
 
 T = TypeVar("T")
 P = ParamSpec("P")
 
 
-class
+class ContextVarToCopyStr(BaseModel):
     name: Literal[
-        "ctx_workflow_run_id",
+        "ctx_workflow_run_id",
+        "ctx_step_run_id",
+        "ctx_action_key",
+        "ctx_worker_id",
     ]
     value: str | None
 
 
+class ContextVarToCopyDict(BaseModel):
+    name: Literal["ctx_additional_metadata"]
+    value: JSONSerializableMapping | None
+
+
+class ContextVarToCopy(BaseModel):
+    var: ContextVarToCopyStr | ContextVarToCopyDict = Field(discriminator="name")
+
+
 def copy_context_vars(
     ctx_vars: list[ContextVarToCopy],
     func: Callable[P, T],
@@ -35,16 +48,18 @@ def copy_context_vars(
     **kwargs: P.kwargs,
 ) -> T:
     for var in ctx_vars:
-        if var.name == "ctx_workflow_run_id":
-            ctx_workflow_run_id.set(var.value)
-        elif var.name == "ctx_step_run_id":
-            ctx_step_run_id.set(var.value)
-        elif var.name == "ctx_action_key":
-            ctx_action_key.set(var.value)
-        elif var.name == "ctx_worker_id":
-            ctx_worker_id.set(var.value)
+        if var.var.name == "ctx_workflow_run_id":
+            ctx_workflow_run_id.set(var.var.value)
+        elif var.var.name == "ctx_step_run_id":
+            ctx_step_run_id.set(var.var.value)
+        elif var.var.name == "ctx_action_key":
+            ctx_action_key.set(var.var.value)
+        elif var.var.name == "ctx_worker_id":
+            ctx_worker_id.set(var.var.value)
+        elif var.var.name == "ctx_additional_metadata":
+            ctx_additional_metadata.set(var.var.value or {})
         else:
-            raise ValueError(f"Unknown context variable name: {var.name}")
+            raise ValueError(f"Unknown context variable name: {var.var.name}")
 
     return func(*args, **kwargs)
 
@@ -73,13 +88,13 @@ class AsyncLogSender:
                     step_run_id=record.step_run_id,
                 )
             except Exception:
-                logger.exception("
+                logger.exception("failed to send log to Hatchet")
 
     def publish(self, record: LogRecord | STOP_LOOP_TYPE) -> None:
         try:
             self.q.put_nowait(record)
         except asyncio.QueueFull:
-            logger.warning("
+            logger.warning("log queue is full, dropping log message")
 
 
 class CustomLogHandler(logging.StreamHandler):  # type: ignore[type-arg]
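
`ContextVarToCopy` is now a thin wrapper around a discriminated union (`Field(discriminator="name")`), so the string-valued context variables and the dict-valued `ctx_additional_metadata` are modelled separately. A sketch of how the new models could be constructed and handed to `copy_context_vars`, mirroring the `runner.py` hunk above; the run ID and metadata values are placeholders:

from hatchet_sdk.worker.runner.utils.capture_logs import (
    ContextVarToCopy,
    ContextVarToCopyDict,
    ContextVarToCopyStr,
    copy_context_vars,
)

ctx_vars = [
    ContextVarToCopy(
        var=ContextVarToCopyStr(name="ctx_workflow_run_id", value="run-123")
    ),
    ContextVarToCopy(
        var=ContextVarToCopyDict(
            name="ctx_additional_metadata", value={"tenant": "acme"}
        )
    ),
]

# copy_context_vars sets each contextvar inside the worker thread before
# calling the wrapped function, so logs emitted there can be correlated
# with the workflow run and its metadata.
result = copy_context_vars(ctx_vars, lambda: "done")
print(result)  # prints: done
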