hatchet-sdk 1.15.2__py3-none-any.whl → 1.16.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.

@@ -2,7 +2,16 @@ import asyncio
 from collections.abc import Callable
 from datetime import datetime, timedelta
 from functools import cached_property
-from typing import TYPE_CHECKING, Any, Generic, TypeVar, cast, get_type_hints
+from typing import (
+    TYPE_CHECKING,
+    Any,
+    Generic,
+    Literal,
+    TypeVar,
+    cast,
+    get_type_hints,
+    overload,
+)
 
 from google.protobuf import timestamp_pb2
 from pydantic import BaseModel, model_validator
@@ -651,39 +660,83 @@ class Workflow(BaseWorkflow[TWorkflowInput]):
 
         return await ref.aio_result()
 
+    def _get_result(
+        self, ref: WorkflowRunRef, return_exceptions: bool
+    ) -> dict[str, Any] | BaseException:
+        try:
+            return ref.result()
+        except Exception as e:
+            if return_exceptions:
+                return e
+            raise e
+
+    @overload
+    def run_many(
+        self,
+        workflows: list[WorkflowRunTriggerConfig],
+        return_exceptions: Literal[True],
+    ) -> list[dict[str, Any] | BaseException]: ...
+
+    @overload
     def run_many(
         self,
         workflows: list[WorkflowRunTriggerConfig],
-    ) -> list[dict[str, Any]]:
+        return_exceptions: Literal[False] = False,
+    ) -> list[dict[str, Any]]: ...
+
+    def run_many(
+        self,
+        workflows: list[WorkflowRunTriggerConfig],
+        return_exceptions: bool = False,
+    ) -> list[dict[str, Any]] | list[dict[str, Any] | BaseException]:
         """
         Run a workflow in bulk and wait for all runs to complete.
         This method triggers multiple workflow runs, blocks until all of them complete, and returns the final results.
 
         :param workflows: A list of `WorkflowRunTriggerConfig` objects, each representing a workflow run to be triggered.
+        :param return_exceptions: If `True`, exceptions will be returned as part of the results instead of raising them.
         :returns: A list of results for each workflow run.
         """
         refs = self.client._client.admin.run_workflows(
             workflows=workflows,
         )
 
-        return [ref.result() for ref in refs]
+        return [self._get_result(ref, return_exceptions) for ref in refs]
 
+    @overload
     async def aio_run_many(
         self,
         workflows: list[WorkflowRunTriggerConfig],
-    ) -> list[dict[str, Any]]:
+        return_exceptions: Literal[True],
+    ) -> list[dict[str, Any] | BaseException]: ...
+
+    @overload
+    async def aio_run_many(
+        self,
+        workflows: list[WorkflowRunTriggerConfig],
+        return_exceptions: Literal[False] = False,
+    ) -> list[dict[str, Any]]: ...
+
+    async def aio_run_many(
+        self,
+        workflows: list[WorkflowRunTriggerConfig],
+        return_exceptions: bool = False,
+    ) -> list[dict[str, Any]] | list[dict[str, Any] | BaseException]:
         """
         Run a workflow in bulk and wait for all runs to complete.
         This method triggers multiple workflow runs, blocks until all of them complete, and returns the final results.
 
         :param workflows: A list of `WorkflowRunTriggerConfig` objects, each representing a workflow run to be triggered.
+        :param return_exceptions: If `True`, exceptions will be returned as part of the results instead of raising them.
         :returns: A list of results for each workflow run.
         """
         refs = await self.client._client.admin.aio_run_workflows(
             workflows=workflows,
         )
 
-        return await asyncio.gather(*[ref.aio_result() for ref in refs])
+        return await asyncio.gather(
+            *[ref.aio_result() for ref in refs], return_exceptions=return_exceptions
+        )
 
     def run_many_no_wait(
         self,
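
Note on the hunk above: `return_exceptions` mirrors the semantics of `asyncio.gather`. A minimal usage sketch, assuming a `Workflow` object named `my_workflow`, a hypothetical Pydantic input model `MyInput`, and the SDK's `create_bulk_run_item` helper for building `WorkflowRunTriggerConfig` objects:

    # Failed runs come back as exception objects instead of raising
    # and aborting the whole batch.
    results = my_workflow.run_many(
        workflows=[
            my_workflow.create_bulk_run_item(input=MyInput(n=i)) for i in range(10)
        ],
        return_exceptions=True,
    )

    for result in results:
        if isinstance(result, BaseException):
            print(f"run failed: {result}")
        else:
            print(f"run succeeded: {result}")
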
@@ -946,7 +999,7 @@ class Workflow(BaseWorkflow[TWorkflowInput]):
 
         :param backoff_max_seconds: The maximum number of seconds to allow retries with exponential backoff to continue.
 
-        :param concurrency: A list of concurrency expressions for the on-success task.
+        :param concurrency: A list of concurrency expressions for the on-failure task.
 
         :returns: A decorator which creates a `Task` object.
         """
@@ -1137,7 +1190,18 @@ class Standalone(BaseWorkflow[TWorkflowInput], Generic[TWorkflowInput, R]):
 
         self.config = self._workflow.config
 
-    def _extract_result(self, result: dict[str, Any]) -> R:
+    @overload
+    def _extract_result(self, result: dict[str, Any]) -> R: ...
+
+    @overload
+    def _extract_result(self, result: BaseException) -> BaseException: ...
+
+    def _extract_result(
+        self, result: dict[str, Any] | BaseException
+    ) -> R | BaseException:
+        if isinstance(result, BaseException):
+            return result
+
         output = result.get(self._task.name)
 
         if not self._output_validator:
@@ -1217,30 +1281,72 @@ class Standalone(BaseWorkflow[TWorkflowInput], Generic[TWorkflowInput, R]):
 
         return TaskRunRef[TWorkflowInput, R](self, ref)
 
-    def run_many(self, workflows: list[WorkflowRunTriggerConfig]) -> list[R]:
+    @overload
+    def run_many(
+        self,
+        workflows: list[WorkflowRunTriggerConfig],
+        return_exceptions: Literal[True],
+    ) -> list[R | BaseException]: ...
+
+    @overload
+    def run_many(
+        self,
+        workflows: list[WorkflowRunTriggerConfig],
+        return_exceptions: Literal[False] = False,
+    ) -> list[R]: ...
+
+    def run_many(
+        self, workflows: list[WorkflowRunTriggerConfig], return_exceptions: bool = False
+    ) -> list[R] | list[R | BaseException]:
         """
         Run a workflow in bulk and wait for all runs to complete.
         This method triggers multiple workflow runs, blocks until all of them complete, and returns the final results.
 
         :param workflows: A list of `WorkflowRunTriggerConfig` objects, each representing a workflow run to be triggered.
+        :param return_exceptions: If `True`, exceptions will be returned as part of the results instead of raising them.
         :returns: A list of results for each workflow run.
         """
         return [
             self._extract_result(result)
-            for result in self._workflow.run_many(workflows)
+            for result in self._workflow.run_many(
+                workflows,
+                ## hack: typing needs literal
+                True if return_exceptions else False,  # noqa: SIM210
+            )
         ]
 
-    async def aio_run_many(self, workflows: list[WorkflowRunTriggerConfig]) -> list[R]:
+    @overload
+    async def aio_run_many(
+        self,
+        workflows: list[WorkflowRunTriggerConfig],
+        return_exceptions: Literal[True],
+    ) -> list[R | BaseException]: ...
+
+    @overload
+    async def aio_run_many(
+        self,
+        workflows: list[WorkflowRunTriggerConfig],
+        return_exceptions: Literal[False] = False,
+    ) -> list[R]: ...
+
+    async def aio_run_many(
+        self, workflows: list[WorkflowRunTriggerConfig], return_exceptions: bool = False
+    ) -> list[R] | list[R | BaseException]:
         """
         Run a workflow in bulk and wait for all runs to complete.
         This method triggers multiple workflow runs, blocks until all of them complete, and returns the final results.
 
         :param workflows: A list of `WorkflowRunTriggerConfig` objects, each representing a workflow run to be triggered.
+        :param return_exceptions: If `True`, exceptions will be returned as part of the results instead of raising them.
         :returns: A list of results for each workflow run.
         """
         return [
             self._extract_result(result)
-            for result in await self._workflow.aio_run_many(workflows)
+            for result in await self._workflow.aio_run_many(
+                workflows,
+                ## hack: typing needs literal
+                True if return_exceptions else False,  # noqa: SIM210
+            )
         ]
 
     def run_many_no_wait(
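
For `Standalone` tasks the flag narrows to the task's declared output type, so the result is `list[R | BaseException]`. A sketch under the same assumptions as above (hypothetical task `my_task` with output model `MyOutput`):

    # Successes and failures can be split without a try/except per run.
    async def run_batch() -> None:
        results = await my_task.aio_run_many(
            workflows=[
                my_task.create_bulk_run_item(input=MyInput(n=i)) for i in range(3)
            ],
            return_exceptions=True,
        )
        outputs = [r for r in results if not isinstance(r, BaseException)]
        errors = [r for r in results if isinstance(r, BaseException)]
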
@@ -1273,3 +1379,104 @@ class Standalone(BaseWorkflow[TWorkflowInput], Generic[TWorkflowInput, R]):
         refs = await self._workflow.aio_run_many_no_wait(workflows)
 
         return [TaskRunRef[TWorkflowInput, R](self, ref) for ref in refs]
+
+    def mock_run(
+        self,
+        input: TWorkflowInput | None = None,
+        additional_metadata: JSONSerializableMapping | None = None,
+        parent_outputs: dict[str, JSONSerializableMapping] | None = None,
+        retry_count: int = 0,
+        lifespan: Any = None,
+    ) -> R:
+        """
+        Mimic the execution of a task. This method is intended to be used to unit test
+        tasks without needing to interact with the Hatchet engine. Use `mock_run` for sync
+        tasks and `aio_mock_run` for async tasks.
+
+        :param input: The input to the task.
+        :param additional_metadata: Additional metadata to attach to the task.
+        :param parent_outputs: Outputs from parent tasks, if any. This is useful for mimicking DAG functionality. For instance, if you have a task `step_2` that has a `parent` which is `step_1`, you can pass `parent_outputs={"step_1": {"result": "Hello, world!"}}` to `step_2.mock_run()` to be able to access `ctx.task_output(step_1)` in `step_2`.
+        :param retry_count: The number of times the task has been retried.
+        :param lifespan: The lifespan to be used in the task, which is useful if one was set on the worker. This will allow you to access `ctx.lifespan` inside of your task.
+
+        :return: The output of the task.
+        """
+
+        return self._task.mock_run(
+            input=input,
+            additional_metadata=additional_metadata,
+            parent_outputs=parent_outputs,
+            retry_count=retry_count,
+            lifespan=lifespan,
+        )
+
+    async def aio_mock_run(
+        self,
+        input: TWorkflowInput | None = None,
+        additional_metadata: JSONSerializableMapping | None = None,
+        parent_outputs: dict[str, JSONSerializableMapping] | None = None,
+        retry_count: int = 0,
+        lifespan: Any = None,
+    ) -> R:
+        """
+        Mimic the execution of a task. This method is intended to be used to unit test
+        tasks without needing to interact with the Hatchet engine. Use `mock_run` for sync
+        tasks and `aio_mock_run` for async tasks.
+
+        :param input: The input to the task.
+        :param additional_metadata: Additional metadata to attach to the task.
+        :param parent_outputs: Outputs from parent tasks, if any. This is useful for mimicking DAG functionality. For instance, if you have a task `step_2` that has a `parent` which is `step_1`, you can pass `parent_outputs={"step_1": {"result": "Hello, world!"}}` to `step_2.mock_run()` to be able to access `ctx.task_output(step_1)` in `step_2`.
+        :param retry_count: The number of times the task has been retried.
+        :param lifespan: The lifespan to be used in the task, which is useful if one was set on the worker. This will allow you to access `ctx.lifespan` inside of your task.
+
+        :return: The output of the task.
+        """
+
+        return await self._task.aio_mock_run(
+            input=input,
+            additional_metadata=additional_metadata,
+            parent_outputs=parent_outputs,
+            retry_count=retry_count,
+            lifespan=lifespan,
+        )
+
+    @property
+    def is_async_function(self) -> bool:
+        """
+        Check if the task is an async function.
+
+        :returns: True if the task is an async function, False otherwise.
+        """
+        return self._task.is_async_function
+
+    def get_run_ref(self, run_id: str) -> TaskRunRef[TWorkflowInput, R]:
+        """
+        Get a reference to a task run by its run ID.
+
+        :param run_id: The ID of the run to get the reference for.
+        :returns: A `TaskRunRef` object representing the reference to the task run.
+        """
+        wrr = self._workflow.client._client.runs.get_run_ref(run_id)
+        return TaskRunRef[TWorkflowInput, R](self, wrr)
+
+    async def aio_get_result(self, run_id: str) -> R:
+        """
+        Get the result of a task run by its run ID.
+
+        :param run_id: The ID of the run to get the result for.
+        :returns: The result of the task run.
+        """
+        run_ref = self.get_run_ref(run_id)
+
+        return await run_ref.aio_result()
+
+    def get_result(self, run_id: str) -> R:
+        """
+        Get the result of a task run by its run ID.
+
+        :param run_id: The ID of the run to get the result for.
+        :returns: The result of the task run.
+        """
+        run_ref = self.get_run_ref(run_id)
+
+        return run_ref.result()
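
The new `mock_run`/`aio_mock_run` helpers allow unit testing tasks without a running Hatchet engine. A minimal sketch following the docstring's own example, assuming a hypothetical standalone task `step_2` whose parent is `step_1` and whose input model is `MyInput`:

    # parent_outputs mimics DAG state, so ctx.task_output(step_1)
    # resolves inside step_2 without the engine being involved.
    def test_step_2() -> None:
        result = step_2.mock_run(
            input=MyInput(name="world"),
            parent_outputs={"step_1": {"result": "Hello, world!"}},
        )
        assert result is not None
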
@@ -2,10 +2,7 @@ from typing import Any, TypeVar, cast, overload
 
 T = TypeVar("T")
 K = TypeVar("K")
-
-
-@overload
-def remove_null_unicode_character(data: str, replacement: str = "") -> str: ...
+R = TypeVar("R")
 
 
 @overload
@@ -24,9 +21,13 @@ def remove_null_unicode_character(
 ) -> tuple[T, ...]: ...
 
 
+@overload
+def remove_null_unicode_character(data: R, replacement: str = "") -> R: ...
+
+
 def remove_null_unicode_character(
-    data: str | dict[K, T] | list[T] | tuple[T, ...], replacement: str = ""
-) -> str | dict[K, T] | list[T] | tuple[T, ...]:
+    data: dict[K, T] | list[T] | tuple[T, ...] | R, replacement: str = ""
+) -> str | dict[K, T] | list[T] | tuple[T, ...] | R:
     """
     Recursively traverse a dictionary (a task's output) and remove the unicode escape sequence \\u0000 which will cause unexpected behavior in Hatchet.
 
@@ -36,7 +37,6 @@ def remove_null_unicode_character(
     :param replacement: The string to replace \\u0000 with.
 
     :return: The same dictionary with all \\u0000 characters removed from strings, and nested dictionaries/lists processed recursively.
-    :raises TypeError: If the input is not a string, dictionary, list, or tuple.
     """
     if isinstance(data, str):
         return data.replace("\u0000", replacement)
@@ -57,6 +57,4 @@ def remove_null_unicode_character(
         remove_null_unicode_character(cast(Any, item), replacement) for item in data
     )
 
-    raise TypeError(
-        f"Unsupported type {type(data)}. Expected str, dict, list, or tuple."
-    )
+    return data
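
The effect of dropping the final `TypeError`: values the function does not know how to traverse (ints, `None`, arbitrary objects, including ones nested inside containers) now pass through unchanged rather than aborting serialization. A quick sketch of the new behavior, assuming the function is imported from its utils module:

    # Strings are scrubbed, containers are walked recursively,
    # everything else is returned as-is.
    assert remove_null_unicode_character("a\u0000b") == "ab"
    assert remove_null_unicode_character({"k": ["x\u0000", 1]}) == {"k": ["x", 1]}
    assert remove_null_unicode_character(42) == 42  # previously raised TypeError
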
@@ -132,8 +132,8 @@ class WorkerActionListenerProcess:
             )
 
             logger.debug(f"acquired action listener: {self.listener.worker_id}")
-        except grpc.RpcError as rpc_error:
-            logger.error(f"could not start action listener: {rpc_error}")
+        except grpc.RpcError:
+            logger.exception("could not start action listener")
             return
 
         # Start both loops as background tasks
@@ -168,7 +168,7 @@ class WorkerActionListenerProcess:
                 count += 1
 
             if count > 0:
-                logger.warning(f"{BLOCKED_THREAD_WARNING}: Waiting Steps {count}")
+                logger.warning(f"{BLOCKED_THREAD_WARNING} Waiting Steps {count}")
             await asyncio.sleep(1)
 
     async def send_event(self, event: ActionEvent, retry_attempt: int = 1) -> None:
@@ -188,7 +188,7 @@ class WorkerActionListenerProcess:
                        )
                        if diff > 0.1:
                            logger.warning(
-                               f"{BLOCKED_THREAD_WARNING}: time to start: {diff}s"
+                               f"{BLOCKED_THREAD_WARNING} time to start: {diff}s"
                            )
                        else:
                            logger.debug(f"start time: {diff}")
@@ -225,9 +225,9 @@ class WorkerActionListenerProcess:
                     )
                 case _:
                     logger.error("unknown action type for event send")
-        except Exception as e:
-            logger.error(
-                f"could not send action event ({retry_attempt}/{ACTION_EVENT_RETRY_COUNT}): {e}"
+        except Exception:
+            logger.exception(
+                f"could not send action event ({retry_attempt}/{ACTION_EVENT_RETRY_COUNT})"
             )
             if retry_attempt <= ACTION_EVENT_RETRY_COUNT:
                 await exp_backoff_sleep(retry_attempt, 1)
@@ -291,11 +291,11 @@ class WorkerActionListenerProcess:
                     )
                     try:
                         self.action_queue.put(action)
-                    except Exception as e:
-                        logger.error(f"error putting action: {e}")
+                    except Exception:
+                        logger.exception("error putting action")
 
-        except Exception as e:
-            logger.error(f"error in action loop: {e}")
+        except Exception:
+            logger.exception("error in action loop")
         finally:
             logger.info("action loop closed")
             if not self.killing:
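
The `logger.error(f"...: {e}")` to `logger.exception(...)` swaps above attach the active traceback to the log record instead of only the exception's string form. A minimal illustration, independent of the SDK:

    import logging

    logger = logging.getLogger(__name__)

    try:
        raise RuntimeError("boom")
    except Exception:
        # Inside an except block, logger.exception logs at ERROR level
        # and appends the full traceback automatically.
        logger.exception("error in action loop")
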
@@ -40,6 +40,7 @@ from hatchet_sdk.logger import logger
 from hatchet_sdk.runnables.action import Action, ActionKey, ActionType
 from hatchet_sdk.runnables.contextvars import (
     ctx_action_key,
+    ctx_additional_metadata,
     ctx_step_run_id,
     ctx_worker_id,
     ctx_workflow_run_id,
@@ -54,6 +55,8 @@ from hatchet_sdk.worker.action_listener_process import ActionEvent
 from hatchet_sdk.worker.runner.utils.capture_logs import (
     AsyncLogSender,
     ContextVarToCopy,
+    ContextVarToCopyDict,
+    ContextVarToCopyStr,
     copy_context_vars,
 )
 
@@ -295,6 +298,7 @@ class Runner:
         ctx_workflow_run_id.set(action.workflow_run_id)
         ctx_worker_id.set(action.worker_id)
         ctx_action_key.set(action.key)
+        ctx_additional_metadata.set(action.additional_metadata)
 
         try:
             if task.is_async_function:
@@ -305,20 +309,34 @@ class Runner:
                     copy_context_vars,
                     [
                         ContextVarToCopy(
-                            name="ctx_step_run_id",
-                            value=action.step_run_id,
+                            var=ContextVarToCopyStr(
+                                name="ctx_step_run_id",
+                                value=action.step_run_id,
+                            )
+                        ),
+                        ContextVarToCopy(
+                            var=ContextVarToCopyStr(
+                                name="ctx_workflow_run_id",
+                                value=action.workflow_run_id,
+                            )
                         ),
                         ContextVarToCopy(
-                            name="ctx_workflow_run_id",
-                            value=action.workflow_run_id,
+                            var=ContextVarToCopyStr(
+                                name="ctx_worker_id",
+                                value=action.worker_id,
+                            )
                         ),
                         ContextVarToCopy(
-                            name="ctx_worker_id",
-                            value=action.worker_id,
+                            var=ContextVarToCopyStr(
+                                name="ctx_action_key",
+                                value=action.key,
+                            )
                         ),
                         ContextVarToCopy(
-                            name="ctx_action_key",
-                            value=action.key,
+                            var=ContextVarToCopyDict(
+                                name="ctx_additional_metadata",
+                                value=action.additional_metadata,
+                            )
                         ),
                     ],
                     self.thread_action_func,
@@ -344,34 +362,34 @@ class Runner:
             "threads_daemon": sum(1 for t in self.thread_pool._threads if t.daemon),
         }
 
-        logger.warning("Thread pool detailed status %s", thread_pool_details)
+        logger.warning("thread pool detailed status %s", thread_pool_details)
 
     async def _start_monitoring(self) -> None:
-        logger.debug("Thread pool monitoring started")
+        logger.debug("thread pool monitoring started")
         try:
             while True:
                 await self.log_thread_pool_status()
 
                 for key in self.threads:
                     if key not in self.tasks:
-                        logger.debug(f"Potential zombie thread found for key {key}")
+                        logger.debug(f"potential zombie thread found for key {key}")
 
                 for key, task in self.tasks.items():
                     if task.done() and key in self.threads:
                         logger.debug(
-                            f"Task is done but thread still exists for key {key}"
+                            f"task is done but thread still exists for key {key}"
                         )
 
                 await asyncio.sleep(60)
         except asyncio.CancelledError:
-            logger.warning("Thread pool monitoring task cancelled")
+            logger.warning("thread pool monitoring task cancelled")
         except Exception as e:
-            logger.exception(f"Error in thread pool monitoring: {e}")
+            logger.exception(f"error in thread pool monitoring: {e}")
 
     def start_background_monitoring(self) -> None:
         loop = asyncio.get_event_loop()
         self.monitoring_task = loop.create_task(self._start_monitoring())
-        logger.debug("Started thread pool monitoring background task")
+        logger.debug("started thread pool monitoring background task")
 
     def cleanup_run_id(self, key: ActionKey) -> None:
         if key in self.tasks:
@@ -503,7 +521,7 @@ class Runner:
 
         ident = cast(int, thread.ident)
 
-        logger.info(f"Forcefully terminating thread {ident}")
+        logger.info(f"forcefully terminating thread {ident}")
 
         exc = ctypes.py_object(SystemExit)
         res = ctypes.pythonapi.PyThreadState_SetAsyncExc(ctypes.c_long(ident), exc)
@@ -516,13 +534,13 @@ class Runner:
                 ctypes.pythonapi.PyThreadState_SetAsyncExc(thread.ident, 0)
                 raise SystemError("PyThreadState_SetAsyncExc failed")
 
-            logger.info(f"Successfully terminated thread {ident}")
+            logger.info(f"successfully terminated thread {ident}")
 
             # Immediately add a new thread to the thread pool, because we've actually killed a worker
             # in the ThreadPoolExecutor
             self.thread_pool.submit(lambda: None)
         except Exception as e:
-            logger.exception(f"Failed to terminate thread: {e}")
+            logger.exception(f"failed to terminate thread: {e}")
 
     ## IMPORTANT: Keep this method's signature in sync with the wrapper in the OTel instrumentor
     async def handle_cancel_action(self, action: Action) -> None:
@@ -546,7 +564,7 @@ class Runner:
             await asyncio.sleep(1)
 
             logger.warning(
-                f"Thread {self.threads[key].ident} with key {key} is still running after cancellation. This could cause the thread pool to get blocked and prevent new tasks from running."
+                f"thread {self.threads[key].ident} with key {key} is still running after cancellation. This could cause the thread pool to get blocked and prevent new tasks from running."
             )
         finally:
             self.cleanup_run_id(key)
@@ -568,8 +586,8 @@ class Runner:
 
         try:
             serialized_output = json.dumps(output, default=str)
-        except Exception as e:
-            logger.error(f"Could not serialize output: {e}")
+        except Exception:
+            logger.exception("could not serialize output")
             serialized_output = str(output)
 
         if "\\u0000" in serialized_output:
@@ -5,29 +5,42 @@ from collections.abc import Awaitable, Callable
 from io import StringIO
 from typing import Literal, ParamSpec, TypeVar
 
-from pydantic import BaseModel
+from pydantic import BaseModel, Field
 
 from hatchet_sdk.clients.events import EventClient
 from hatchet_sdk.logger import logger
 from hatchet_sdk.runnables.contextvars import (
     ctx_action_key,
+    ctx_additional_metadata,
     ctx_step_run_id,
     ctx_worker_id,
     ctx_workflow_run_id,
 )
-from hatchet_sdk.utils.typing import STOP_LOOP, STOP_LOOP_TYPE
+from hatchet_sdk.utils.typing import STOP_LOOP, STOP_LOOP_TYPE, JSONSerializableMapping
 
 T = TypeVar("T")
 P = ParamSpec("P")
 
 
-class ContextVarToCopy(BaseModel):
+class ContextVarToCopyStr(BaseModel):
     name: Literal[
-        "ctx_workflow_run_id", "ctx_step_run_id", "ctx_action_key", "ctx_worker_id"
+        "ctx_workflow_run_id",
+        "ctx_step_run_id",
+        "ctx_action_key",
+        "ctx_worker_id",
     ]
     value: str | None
 
 
+class ContextVarToCopyDict(BaseModel):
+    name: Literal["ctx_additional_metadata"]
+    value: JSONSerializableMapping | None
+
+
+class ContextVarToCopy(BaseModel):
+    var: ContextVarToCopyStr | ContextVarToCopyDict = Field(discriminator="name")
+
+
 def copy_context_vars(
     ctx_vars: list[ContextVarToCopy],
     func: Callable[P, T],
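
`ContextVarToCopy` is now a wrapper whose `var` field is a Pydantic discriminated union keyed on the literal `name`, so string-valued and mapping-valued context variables each validate against the correct model. A construction sketch (the values are hypothetical):

    # The discriminator routes each payload to ContextVarToCopyStr
    # or ContextVarToCopyDict based on the `name` literal.
    vars_to_copy = [
        ContextVarToCopy(
            var=ContextVarToCopyStr(name="ctx_workflow_run_id", value="wf-run-123")
        ),
        ContextVarToCopy(
            var=ContextVarToCopyDict(
                name="ctx_additional_metadata", value={"tenant": "acme"}
            )
        ),
    ]
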
@@ -35,16 +48,18 @@ def copy_context_vars(
     **kwargs: P.kwargs,
 ) -> T:
     for var in ctx_vars:
-        if var.name == "ctx_workflow_run_id":
-            ctx_workflow_run_id.set(var.value)
-        elif var.name == "ctx_step_run_id":
-            ctx_step_run_id.set(var.value)
-        elif var.name == "ctx_action_key":
-            ctx_action_key.set(var.value)
-        elif var.name == "ctx_worker_id":
-            ctx_worker_id.set(var.value)
+        if var.var.name == "ctx_workflow_run_id":
+            ctx_workflow_run_id.set(var.var.value)
+        elif var.var.name == "ctx_step_run_id":
+            ctx_step_run_id.set(var.var.value)
+        elif var.var.name == "ctx_action_key":
+            ctx_action_key.set(var.var.value)
+        elif var.var.name == "ctx_worker_id":
+            ctx_worker_id.set(var.var.value)
+        elif var.var.name == "ctx_additional_metadata":
+            ctx_additional_metadata.set(var.var.value or {})
         else:
-            raise ValueError(f"Unknown context variable name: {var.name}")
+            raise ValueError(f"Unknown context variable name: {var.var.name}")
 
     return func(*args, **kwargs)
 
@@ -73,13 +88,13 @@ class AsyncLogSender:
                     step_run_id=record.step_run_id,
                 )
             except Exception:
-                logger.exception("Failed to send log to Hatchet")
+                logger.exception("failed to send log to Hatchet")
 
     def publish(self, record: LogRecord | STOP_LOOP_TYPE) -> None:
         try:
             self.q.put_nowait(record)
         except asyncio.QueueFull:
-            logger.warning("Log queue is full, dropping log message")
+            logger.warning("log queue is full, dropping log message")
 
 
 class CustomLogHandler(logging.StreamHandler):  # type: ignore[type-arg]