hatchet-sdk 1.15.3__py3-none-any.whl → 1.16.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of hatchet-sdk might be problematic; see the package registry's page for this release for more details.

Files changed (36):
  1. hatchet_sdk/__init__.py +4 -0
  2. hatchet_sdk/client.py +2 -0
  3. hatchet_sdk/clients/admin.py +3 -1
  4. hatchet_sdk/clients/dispatcher/action_listener.py +13 -13
  5. hatchet_sdk/clients/event_ts.py +1 -1
  6. hatchet_sdk/clients/listeners/pooled_listener.py +4 -4
  7. hatchet_sdk/clients/rest/__init__.py +8 -0
  8. hatchet_sdk/clients/rest/api/__init__.py +1 -0
  9. hatchet_sdk/clients/rest/api/cel_api.py +334 -0
  10. hatchet_sdk/clients/rest/api/task_api.py +12 -10
  11. hatchet_sdk/clients/rest/models/__init__.py +7 -0
  12. hatchet_sdk/clients/rest/models/v1_cancelled_tasks.py +87 -0
  13. hatchet_sdk/clients/rest/models/v1_cel_debug_error_response.py +93 -0
  14. hatchet_sdk/clients/rest/models/v1_cel_debug_request.py +108 -0
  15. hatchet_sdk/clients/rest/models/v1_cel_debug_response.py +100 -0
  16. hatchet_sdk/clients/rest/models/v1_cel_debug_response_status.py +37 -0
  17. hatchet_sdk/clients/rest/models/v1_cel_debug_success_response.py +102 -0
  18. hatchet_sdk/clients/rest/models/v1_replayed_tasks.py +87 -0
  19. hatchet_sdk/clients/rest/tenacity_utils.py +1 -1
  20. hatchet_sdk/clients/v1/api_client.py +3 -3
  21. hatchet_sdk/context/context.py +6 -6
  22. hatchet_sdk/features/cel.py +99 -0
  23. hatchet_sdk/features/runs.py +21 -1
  24. hatchet_sdk/hatchet.py +8 -0
  25. hatchet_sdk/opentelemetry/instrumentor.py +3 -3
  26. hatchet_sdk/runnables/contextvars.py +4 -0
  27. hatchet_sdk/runnables/task.py +133 -0
  28. hatchet_sdk/runnables/workflow.py +218 -11
  29. hatchet_sdk/worker/action_listener_process.py +11 -11
  30. hatchet_sdk/worker/runner/runner.py +39 -21
  31. hatchet_sdk/worker/runner/utils/capture_logs.py +30 -15
  32. hatchet_sdk/worker/worker.py +11 -14
  33. {hatchet_sdk-1.15.3.dist-info → hatchet_sdk-1.16.1.dist-info}/METADATA +1 -1
  34. {hatchet_sdk-1.15.3.dist-info → hatchet_sdk-1.16.1.dist-info}/RECORD +36 -27
  35. {hatchet_sdk-1.15.3.dist-info → hatchet_sdk-1.16.1.dist-info}/WHEEL +0 -0
  36. {hatchet_sdk-1.15.3.dist-info → hatchet_sdk-1.16.1.dist-info}/entry_points.txt +0 -0
@@ -0,0 +1,99 @@
1
+ import asyncio
2
+ from typing import Literal
3
+
4
+ from pydantic import BaseModel, Field
5
+
6
+ from hatchet_sdk.clients.rest.api.cel_api import CELApi
7
+ from hatchet_sdk.clients.rest.api_client import ApiClient
8
+ from hatchet_sdk.clients.rest.models.v1_cel_debug_request import V1CELDebugRequest
9
+ from hatchet_sdk.clients.rest.models.v1_cel_debug_response_status import (
10
+ V1CELDebugResponseStatus,
11
+ )
12
+ from hatchet_sdk.clients.v1.api_client import BaseRestClient, retry
13
+ from hatchet_sdk.utils.typing import JSONSerializableMapping
14
+
15
+
16
# Successful CEL evaluation result.
# `status` is the literal tag that CELEvaluationResult uses as its pydantic
# discriminator; `output` holds the boolean value the CEL debug API returned.
class CELSuccess(BaseModel):
    status: Literal["success"] = "success"
    output: bool
19
+
20
+
21
# Failed CEL evaluation result.
# `status` is the literal discriminator tag; `error` carries the error message
# reported by the CEL debug API.
class CELFailure(BaseModel):
    status: Literal["failure"] = "failure"
    error: str
24
+
25
+
26
# Wrapper returned by CELClient.debug / aio_debug.
# `result` is a discriminated union: pydantic picks CELSuccess or CELFailure
# based on the value of the `status` field.
class CELEvaluationResult(BaseModel):
    result: CELSuccess | CELFailure = Field(discriminator="status")
28
+
29
+
30
class CELClient(BaseRestClient):
    """
    The CEL client is a client for debugging CEL expressions within Hatchet.
    """

    def _ca(self, client: ApiClient) -> CELApi:
        # Bind the generated CEL REST API wrapper to the given API client.
        return CELApi(client)

    @retry
    def debug(
        self,
        expression: str,
        input: JSONSerializableMapping,
        additional_metadata: JSONSerializableMapping | None = None,
        filter_payload: JSONSerializableMapping | None = None,
    ) -> CELEvaluationResult:
        """
        Debug a CEL expression with the provided input, filter payload, and optional metadata. Useful for testing and validating CEL expressions and debugging issues in production.

        :param expression: The CEL expression to debug.
        :param input: The input, which simulates the workflow run input.
        :param additional_metadata: Additional metadata, which simulates metadata that could be sent with an event or a workflow run.
        :param filter_payload: The filter payload, which simulates a payload set on a previously-created filter.

        :raises ValueError: If the CEL debug API reports an error status without an error message, or a non-error status without an output.

        :return: A `CELEvaluationResult` whose `result` is a `CELFailure` (carrying the error message) when the API reports an error, or a `CELSuccess` (carrying the boolean output) otherwise.
        """
        request = V1CELDebugRequest(
            expression=expression,
            input=input,
            additionalMetadata=additional_metadata,
            filterPayload=filter_payload,
        )
        with self.client() as client:
            result = self._ca(client).v1_cel_debug(
                tenant=self.client_config.tenant_id, v1_cel_debug_request=request
            )

        if result.status == V1CELDebugResponseStatus.ERROR:
            if result.error is None:
                raise ValueError("No error message received from CEL debug API.")

            return CELEvaluationResult(result=CELFailure(error=result.error))

        # Any non-ERROR status is treated as success and must carry an output.
        if result.output is None:
            raise ValueError("No output received from CEL debug API.")

        return CELEvaluationResult(result=CELSuccess(output=result.output))

    async def aio_debug(
        self,
        expression: str,
        input: JSONSerializableMapping,
        additional_metadata: JSONSerializableMapping | None = None,
        filter_payload: JSONSerializableMapping | None = None,
    ) -> CELEvaluationResult:
        """
        Debug a CEL expression with the provided input, filter payload, and optional metadata. Useful for testing and validating CEL expressions and debugging issues in production.

        Runs the synchronous `debug` method in a worker thread via `asyncio.to_thread`.

        :param expression: The CEL expression to debug.
        :param input: The input, which simulates the workflow run input.
        :param additional_metadata: Additional metadata, which simulates metadata that could be sent with an event or a workflow run.
        :param filter_payload: The filter payload, which simulates a payload set on a previously-created filter.

        :raises ValueError: If the CEL debug API reports an error status without an error message, or a non-error status without an output.

        :return: A `CELEvaluationResult` whose `result` is a `CELFailure` (carrying the error message) when the API reports an error, or a `CELSuccess` (carrying the boolean output) otherwise.
        """
        return await asyncio.to_thread(
            self.debug, expression, input, additional_metadata, filter_payload
        )
@@ -119,6 +119,26 @@ class RunsClient(BaseRestClient):
def _ta(self, client: ApiClient) -> TaskApi:
    """Return a `TaskApi` bound to the given REST `client`."""
    task_api = TaskApi(client)
    return task_api
121
121
 
122
@retry
def get_task_run(self, task_run_id: str) -> V1TaskSummary:
    """
    Fetch the summary of a single task run.

    :param task_run_id: ID of the task run to look up.
    :return: The `V1TaskSummary` the task API returns for that ID.
    """
    with self.client() as rest_client:
        task_api = self._ta(rest_client)
        return task_api.v1_task_get(task_run_id)
132
+
133
async def aio_get_task_run(self, task_run_id: str) -> V1TaskSummary:
    """
    Async counterpart of `get_task_run`.

    The blocking `get_task_run` call is offloaded to a worker thread.

    :param task_run_id: ID of the task run to look up.
    :return: The task run summary for that ID.
    """
    fetch = self.get_task_run
    return await asyncio.to_thread(fetch, task_run_id)
141
+
122
142
  @retry
123
143
  def get(self, workflow_run_id: str) -> V1WorkflowRunDetails:
124
144
  """
@@ -148,7 +168,7 @@ class RunsClient(BaseRestClient):
148
168
  :return: The task status
149
169
  """
150
170
  with self.client() as client:
151
- return self._wra(client).v1_workflow_run_get_status(str(workflow_run_id))
171
+ return self._wra(client).v1_workflow_run_get_status(workflow_run_id)
152
172
 
153
173
  async def aio_get_status(self, workflow_run_id: str) -> V1TaskStatus:
154
174
  """
hatchet_sdk/hatchet.py CHANGED
@@ -12,6 +12,7 @@ from hatchet_sdk.clients.events import EventClient
12
12
  from hatchet_sdk.clients.listeners.run_event_listener import RunEventListenerClient
13
13
  from hatchet_sdk.clients.rest.models.tenant_version import TenantVersion
14
14
  from hatchet_sdk.config import ClientConfig
15
+ from hatchet_sdk.features.cel import CELClient
15
16
  from hatchet_sdk.features.cron import CronClient
16
17
  from hatchet_sdk.features.filters import FiltersClient
17
18
  from hatchet_sdk.features.logs import LogsClient
@@ -66,6 +67,13 @@ class Hatchet:
66
67
  "🚨⚠️‼️ YOU ARE USING A V0 ENGINE WITH A V1 SDK, WHICH IS NOT SUPPORTED. PLEASE UPGRADE YOUR ENGINE TO V1.🚨⚠️‼️"
67
68
  )
68
69
 
70
@property
def cel(self) -> CELClient:
    """
    Access the CEL client, used to debug CEL expressions against Hatchet's CEL API.
    """
    underlying_client = self._client
    return underlying_client.cel
76
+
69
77
  @property
70
78
  def cron(self) -> CronClient:
71
79
  """
@@ -64,7 +64,7 @@ OTEL_TRACEPARENT_KEY = "traceparent"
64
64
 
def create_traceparent() -> str | None:
    """
    Deprecated wrapper around `_create_traceparent`.

    Logs a warning that manual traceparent creation is no longer needed
    (since SDK 1.11.0), then delegates to `_create_traceparent`.

    :returns: Whatever `_create_traceparent()` returns.
    """
    deprecation_message = "as of SDK version 1.11.0, you no longer need to call `create_traceparent` manually. The traceparent will be automatically created by the instrumentor and injected into the metadata of actions and events when appropriate. This method will be removed in a future version."
    logger.warning(deprecation_message)
    return _create_traceparent()
70
70
 
@@ -91,7 +91,7 @@ def parse_carrier_from_metadata(
91
91
  metadata: JSONSerializableMapping | None,
92
92
  ) -> Context | None:
93
93
  logger.warning(
94
- "As of SDK version 1.11.0, you no longer need to call `parse_carrier_from_metadata` manually. This method will be removed in a future version.",
94
+ "as of SDK version 1.11.0, you no longer need to call `parse_carrier_from_metadata` manually. This method will be removed in a future version.",
95
95
  )
96
96
 
97
97
  return _parse_carrier_from_metadata(metadata)
@@ -133,7 +133,7 @@ def inject_traceparent_into_metadata(
133
133
  metadata: dict[str, str], traceparent: str | None = None
134
134
  ) -> dict[str, str]:
135
135
  logger.warning(
136
- "As of SDK version 1.11.0, you no longer need to call `inject_traceparent_into_metadata` manually. The traceparent will automatically be injected by the instrumentor. This method will be removed in a future version.",
136
+ "as of SDK version 1.11.0, you no longer need to call `inject_traceparent_into_metadata` manually. The traceparent will automatically be injected by the instrumentor. This method will be removed in a future version.",
137
137
  )
138
138
 
139
139
  return _inject_traceparent_into_metadata(metadata, traceparent)
@@ -4,6 +4,7 @@ from collections import Counter
4
4
  from contextvars import ContextVar
5
5
 
6
6
  from hatchet_sdk.runnables.action import ActionKey
7
+ from hatchet_sdk.utils.typing import JSONSerializableMapping
7
8
 
8
9
  ctx_workflow_run_id: ContextVar[str | None] = ContextVar(
9
10
  "ctx_workflow_run_id", default=None
@@ -13,6 +14,9 @@ ctx_action_key: ContextVar[ActionKey | None] = ContextVar(
13
14
  )
# Context-local step run ID; defaults to None when unset.
# NOTE(review): presumably set by the runner while a step executes — confirm at call sites.
ctx_step_run_id: ContextVar[str | None] = ContextVar("ctx_step_run_id", default=None)
# Context-local worker ID; defaults to None when unset.
ctx_worker_id: ContextVar[str | None] = ContextVar("ctx_worker_id", default=None)
# Context-local additional metadata mapping (added in 1.16); defaults to None when unset.
ctx_additional_metadata: ContextVar[JSONSerializableMapping | None] = ContextVar(
    "ctx_additional_metadata", default=None
)

# Module-level Counter keyed by ActionKey.
# NOTE(review): the name suggests per-action child-workflow spawn indices — confirm against callers.
workflow_spawn_indices = Counter[ActionKey]()
# Module-level asyncio.Lock; presumably guards updates to workflow_spawn_indices — verify.
spawn_index_lock = asyncio.Lock()
@@ -11,6 +11,7 @@ from hatchet_sdk.conditions import (
11
11
  flatten_conditions,
12
12
  )
13
13
  from hatchet_sdk.context.context import Context, DurableContext
14
+ from hatchet_sdk.context.worker_context import WorkerContext
14
15
  from hatchet_sdk.contracts.v1.shared.condition_pb2 import TaskConditions
15
16
  from hatchet_sdk.contracts.v1.workflows_pb2 import (
16
17
  CreateTaskOpts,
@@ -19,6 +20,7 @@ from hatchet_sdk.contracts.v1.workflows_pb2 import (
19
20
  )
20
21
  from hatchet_sdk.runnables.types import (
21
22
  ConcurrencyExpression,
23
+ EmptyModel,
22
24
  R,
23
25
  StepType,
24
26
  TWorkflowInput,
@@ -30,9 +32,11 @@ from hatchet_sdk.utils.timedelta_to_expression import Duration, timedelta_to_exp
30
32
  from hatchet_sdk.utils.typing import (
31
33
  AwaitableLike,
32
34
  CoroutineLike,
35
+ JSONSerializableMapping,
33
36
  TaskIOValidator,
34
37
  is_basemodel_subclass,
35
38
  )
39
+ from hatchet_sdk.worker.runner.utils.capture_logs import AsyncLogSender
36
40
 
37
41
  if TYPE_CHECKING:
38
42
  from hatchet_sdk.runnables.workflow import Workflow
@@ -186,3 +190,132 @@ class Task(Generic[TWorkflowInput, R]):
186
190
  sleep_conditions=sleep_conditions,
187
191
  user_event_conditions=user_events,
188
192
  )
193
+
194
def _create_mock_context(
    self,
    input: TWorkflowInput | None,
    additional_metadata: JSONSerializableMapping | None = None,
    parent_outputs: dict[str, JSONSerializableMapping] | None = None,
    retry_count: int = 0,
    lifespan_context: Any = None,
) -> Context | DurableContext:
    """
    Build a context backed by a synthetic `Action` so the task can be invoked
    without the Hatchet engine assigning work.

    A `DurableContext` is built for durable tasks, a plain `Context` otherwise.
    The workflow's real dispatcher, admin, event and runs clients are still
    wired into the context.

    :param input: Task input; `EmptyModel()` is used when None.
    :param additional_metadata: Metadata recorded on the action; `{}` when falsy.
    :param parent_outputs: Parent task outputs placed in the payload's `parents`; `{}` when falsy.
    :param retry_count: Retry count recorded on the action.
    :param lifespan_context: Passed through as the context's `lifespan_context`.
    :return: The constructed context.
    """
    from hatchet_sdk.runnables.action import Action, ActionPayload, ActionType

    metadata = additional_metadata or {}
    parents = parent_outputs or {}
    task_input = cast(TWorkflowInput, EmptyModel()) if input is None else input

    payload = ActionPayload(input=task_input.model_dump(), parents=parents)

    # Every identifier is a fixed placeholder; only the tenant is real.
    action = Action(
        tenant_id=self.workflow.client.config.tenant_id,
        worker_id="mock-worker-id",
        workflow_run_id="mock-workflow-run-id",
        get_group_key_run_id="mock-get-group-key-run-id",
        job_id="mock-job-id",
        job_name="mock-job-name",
        job_run_id="mock-job-run-id",
        step_id="mock-step-id",
        step_run_id="mock-step-run-id",
        action_id="mock:action",
        action_payload=payload,
        action_type=ActionType.START_STEP_RUN,
        retry_count=retry_count,
        additional_metadata=metadata,
        child_workflow_index=None,
        child_workflow_key=None,
        parent_workflow_run_id=None,
        priority=1,
        workflow_version_id="mock-workflow-version-id",
        workflow_id="mock-workflow-id",
    )

    context_cls = DurableContext if self.is_durable else Context
    hatchet_client = self.workflow.client._client

    return context_cls(
        action=action,
        dispatcher_client=hatchet_client.dispatcher,
        admin_client=hatchet_client.admin,
        event_client=hatchet_client.event,
        durable_event_listener=None,
        worker=WorkerContext(labels={}, client=hatchet_client.dispatcher),
        runs_client=hatchet_client.runs,
        lifespan_context=lifespan_context,
        log_sender=AsyncLogSender(hatchet_client.event),
    )
250
+
251
def mock_run(
    self,
    input: TWorkflowInput | None = None,
    additional_metadata: JSONSerializableMapping | None = None,
    parent_outputs: dict[str, JSONSerializableMapping] | None = None,
    retry_count: int = 0,
    lifespan: Any = None,
) -> R:
    """
    Run a synchronous task against a mock context, for unit tests that should
    not depend on the Hatchet engine. Async tasks must use `aio_mock_run`.

    :param input: The input to the task.
    :param additional_metadata: Additional metadata to attach to the task.
    :param parent_outputs: Outputs of parent tasks keyed by parent task name, for mimicking DAGs. For example, if `step_2` has parent `step_1`, passing `parent_outputs={"step_1": {"result": "Hello, world!"}}` to `step_2.mock_run()` makes `ctx.task_output(step_1)` available inside `step_2`.
    :param retry_count: The number of times the task has been retried.
    :param lifespan: Lifespan value exposed as `ctx.lifespan`, useful when the worker normally provides one.

    :return: The output of the task.
    :raises TypeError: If the task is an async function (use `aio_mock_run` instead).
    """
    if not self.is_async_function:
        mock_ctx = self._create_mock_context(
            input, additional_metadata, parent_outputs, retry_count, lifespan
        )
        return self.call(mock_ctx)

    raise TypeError(f"{self.name} is not a sync function. Use `aio_mock_run` instead.")
284
+
285
async def aio_mock_run(
    self,
    input: TWorkflowInput | None = None,
    additional_metadata: JSONSerializableMapping | None = None,
    parent_outputs: dict[str, JSONSerializableMapping] | None = None,
    retry_count: int = 0,
    lifespan: Any = None,
) -> R:
    """
    Run an async task against a mock context, for unit tests that should not
    depend on the Hatchet engine. Sync tasks must use `mock_run`.

    :param input: The input to the task.
    :param additional_metadata: Additional metadata to attach to the task.
    :param parent_outputs: Outputs of parent tasks keyed by parent task name, for mimicking DAGs. For example, if `step_2` has parent `step_1`, passing `parent_outputs={"step_1": {"result": "Hello, world!"}}` to `step_2.aio_mock_run()` makes `ctx.task_output(step_1)` available inside `step_2`.
    :param retry_count: The number of times the task has been retried.
    :param lifespan: Lifespan value exposed as `ctx.lifespan`, useful when the worker normally provides one.

    :return: The output of the task.
    :raises TypeError: If the task is a sync function (use `mock_run` instead).
    """
    if self.is_async_function:
        mock_ctx = self._create_mock_context(
            input, additional_metadata, parent_outputs, retry_count, lifespan
        )
        return await self.aio_call(mock_ctx)

    raise TypeError(f"{self.name} is not an async function. Use `mock_run` instead.")
@@ -2,7 +2,16 @@ import asyncio
2
2
  from collections.abc import Callable
3
3
  from datetime import datetime, timedelta
4
4
  from functools import cached_property
5
- from typing import TYPE_CHECKING, Any, Generic, TypeVar, cast, get_type_hints
5
+ from typing import (
6
+ TYPE_CHECKING,
7
+ Any,
8
+ Generic,
9
+ Literal,
10
+ TypeVar,
11
+ cast,
12
+ get_type_hints,
13
+ overload,
14
+ )
6
15
 
7
16
  from google.protobuf import timestamp_pb2
8
17
  from pydantic import BaseModel, model_validator
@@ -651,39 +660,83 @@ class Workflow(BaseWorkflow[TWorkflowInput]):
651
660
 
652
661
  return await ref.aio_result()
653
662
 
663
+ def _get_result(
664
+ self, ref: WorkflowRunRef, return_exceptions: bool
665
+ ) -> dict[str, Any] | BaseException:
666
+ try:
667
+ return ref.result()
668
+ except Exception as e:
669
+ if return_exceptions:
670
+ return e
671
+ raise e
672
+
673
@overload
def run_many(
    self,
    workflows: list[WorkflowRunTriggerConfig],
    return_exceptions: Literal[True],
) -> list[dict[str, Any] | BaseException]: ...

@overload
def run_many(
    self,
    workflows: list[WorkflowRunTriggerConfig],
    return_exceptions: Literal[False] = False,
) -> list[dict[str, Any]]: ...

def run_many(
    self,
    workflows: list[WorkflowRunTriggerConfig],
    return_exceptions: bool = False,
) -> list[dict[str, Any]] | list[dict[str, Any] | BaseException]:
    """
    Trigger a batch of workflow runs and block until every one has finished.

    :param workflows: One `WorkflowRunTriggerConfig` per run to trigger.
    :param return_exceptions: When `True`, a run that fails contributes its exception to the returned list instead of the exception being raised.
    :returns: One result per triggered run, in the order the admin client returned the run references.
    """
    triggered = self.client._client.admin.run_workflows(
        workflows=workflows,
    )

    return [self._get_result(run_ref, return_exceptions) for run_ref in triggered]
670
705
 
706
@overload
async def aio_run_many(
    self,
    workflows: list[WorkflowRunTriggerConfig],
    return_exceptions: Literal[True],
) -> list[dict[str, Any] | BaseException]: ...

@overload
async def aio_run_many(
    self,
    workflows: list[WorkflowRunTriggerConfig],
    return_exceptions: Literal[False] = False,
) -> list[dict[str, Any]]: ...

async def aio_run_many(
    self,
    workflows: list[WorkflowRunTriggerConfig],
    return_exceptions: bool = False,
) -> list[dict[str, Any]] | list[dict[str, Any] | BaseException]:
    """
    Trigger a batch of workflow runs and await until every one has finished.

    :param workflows: One `WorkflowRunTriggerConfig` per run to trigger.
    :param return_exceptions: Forwarded to `asyncio.gather`; when `True`, failed runs yield their exception in the result list instead of raising.
    :returns: One result per triggered run, in the order the admin client returned the run references.
    """
    triggered = await self.client._client.admin.aio_run_workflows(
        workflows=workflows,
    )

    waiters = [run_ref.aio_result() for run_ref in triggered]
    return await asyncio.gather(*waiters, return_exceptions=return_exceptions)
687
740
 
688
741
  def run_many_no_wait(
689
742
  self,
@@ -946,7 +999,7 @@ class Workflow(BaseWorkflow[TWorkflowInput]):
946
999
 
947
1000
  :param backoff_max_seconds: The maximum number of seconds to allow retries with exponential backoff to continue.
948
1001
 
949
- :param concurrency: A list of concurrency expressions for the on-success task.
1002
+ :param concurrency: A list of concurrency expressions for the on-failure task.
950
1003
 
951
1004
  :returns: A decorator which creates a `Task` object.
952
1005
  """
@@ -1137,7 +1190,18 @@ class Standalone(BaseWorkflow[TWorkflowInput], Generic[TWorkflowInput, R]):
1137
1190
 
1138
1191
  self.config = self._workflow.config
1139
1192
 
1140
- def _extract_result(self, result: dict[str, Any]) -> R:
1193
+ @overload
1194
+ def _extract_result(self, result: dict[str, Any]) -> R: ...
1195
+
1196
+ @overload
1197
+ def _extract_result(self, result: BaseException) -> BaseException: ...
1198
+
1199
+ def _extract_result(
1200
+ self, result: dict[str, Any] | BaseException
1201
+ ) -> R | BaseException:
1202
+ if isinstance(result, BaseException):
1203
+ return result
1204
+
1141
1205
  output = result.get(self._task.name)
1142
1206
 
1143
1207
  if not self._output_validator:
@@ -1217,30 +1281,72 @@ class Standalone(BaseWorkflow[TWorkflowInput], Generic[TWorkflowInput, R]):
1217
1281
 
1218
1282
  return TaskRunRef[TWorkflowInput, R](self, ref)
1219
1283
 
1220
@overload
def run_many(
    self,
    workflows: list[WorkflowRunTriggerConfig],
    return_exceptions: Literal[True],
) -> list[R | BaseException]: ...

@overload
def run_many(
    self,
    workflows: list[WorkflowRunTriggerConfig],
    return_exceptions: Literal[False] = False,
) -> list[R]: ...

def run_many(
    self, workflows: list[WorkflowRunTriggerConfig], return_exceptions: bool = False
) -> list[R] | list[R | BaseException]:
    """
    Trigger several runs of this standalone task, block until they all finish,
    and return each run's extracted task output.

    :param workflows: One `WorkflowRunTriggerConfig` per run to trigger.
    :param return_exceptions: When `True`, a run that fails contributes its exception to the returned list instead of the exception being raised.
    :returns: One entry per run, in the order the underlying workflow returned them.
    """
    # The ternary (instead of passing the bool through) hands the type checker
    # a literal it can match against the wrapped workflow's overloads.
    raw_results = self._workflow.run_many(
        workflows,
        True if return_exceptions else False,  # noqa: SIM210
    )
    return [self._extract_result(raw) for raw in raw_results]
1232
1317
 
1233
@overload
async def aio_run_many(
    self,
    workflows: list[WorkflowRunTriggerConfig],
    return_exceptions: Literal[True],
) -> list[R | BaseException]: ...

@overload
async def aio_run_many(
    self,
    workflows: list[WorkflowRunTriggerConfig],
    return_exceptions: Literal[False] = False,
) -> list[R]: ...

async def aio_run_many(
    self, workflows: list[WorkflowRunTriggerConfig], return_exceptions: bool = False
) -> list[R] | list[R | BaseException]:
    """
    Trigger several runs of this standalone task, await them all, and return
    each run's extracted task output.

    :param workflows: One `WorkflowRunTriggerConfig` per run to trigger.
    :param return_exceptions: When `True`, a run that fails contributes its exception to the returned list instead of the exception being raised.
    :returns: One entry per run, in the order the underlying workflow returned them.
    """
    # The ternary hands the type checker a literal for overload matching.
    raw_results = await self._workflow.aio_run_many(
        workflows,
        True if return_exceptions else False,  # noqa: SIM210
    )
    return [self._extract_result(raw) for raw in raw_results]
1245
1351
 
1246
1352
  def run_many_no_wait(
@@ -1273,3 +1379,104 @@ class Standalone(BaseWorkflow[TWorkflowInput], Generic[TWorkflowInput, R]):
1273
1379
  refs = await self._workflow.aio_run_many_no_wait(workflows)
1274
1380
 
1275
1381
  return [TaskRunRef[TWorkflowInput, R](self, ref) for ref in refs]
1382
+
1383
def mock_run(
    self,
    input: TWorkflowInput | None = None,
    additional_metadata: JSONSerializableMapping | None = None,
    parent_outputs: dict[str, JSONSerializableMapping] | None = None,
    retry_count: int = 0,
    lifespan: Any = None,
) -> R:
    """
    Run this standalone task's synchronous function against a mock context,
    for unit tests that should not depend on the Hatchet engine. Async tasks
    must use `aio_mock_run`.

    All arguments are forwarded unchanged to the wrapped task's `mock_run`.

    :param input: The input to the task.
    :param additional_metadata: Additional metadata to attach to the task.
    :param parent_outputs: Outputs of parent tasks keyed by parent task name, for mimicking DAGs. For example, if `step_2` has parent `step_1`, passing `parent_outputs={"step_1": {"result": "Hello, world!"}}` to `step_2.mock_run()` makes `ctx.task_output(step_1)` available inside `step_2`.
    :param retry_count: The number of times the task has been retried.
    :param lifespan: Lifespan value exposed as `ctx.lifespan`, useful when the worker normally provides one.

    :return: The output of the task.
    """
    forwarded = {
        "input": input,
        "additional_metadata": additional_metadata,
        "parent_outputs": parent_outputs,
        "retry_count": retry_count,
        "lifespan": lifespan,
    }
    return self._task.mock_run(**forwarded)
1412
+
1413
async def aio_mock_run(
    self,
    input: TWorkflowInput | None = None,
    additional_metadata: JSONSerializableMapping | None = None,
    parent_outputs: dict[str, JSONSerializableMapping] | None = None,
    retry_count: int = 0,
    lifespan: Any = None,
) -> R:
    """
    Run this standalone task's async function against a mock context, for
    unit tests that should not depend on the Hatchet engine. Sync tasks must
    use `mock_run`.

    All arguments are forwarded unchanged to the wrapped task's `aio_mock_run`.

    :param input: The input to the task.
    :param additional_metadata: Additional metadata to attach to the task.
    :param parent_outputs: Outputs of parent tasks keyed by parent task name, for mimicking DAGs. For example, if `step_2` has parent `step_1`, passing `parent_outputs={"step_1": {"result": "Hello, world!"}}` to `step_2.aio_mock_run()` makes `ctx.task_output(step_1)` available inside `step_2`.
    :param retry_count: The number of times the task has been retried.
    :param lifespan: Lifespan value exposed as `ctx.lifespan`, useful when the worker normally provides one.

    :return: The output of the task.
    """
    forwarded = {
        "input": input,
        "additional_metadata": additional_metadata,
        "parent_outputs": parent_outputs,
        "retry_count": retry_count,
        "lifespan": lifespan,
    }
    return await self._task.aio_mock_run(**forwarded)
1442
+
1443
@property
def is_async_function(self) -> bool:
    """
    Whether the wrapped task function is asynchronous.

    Delegates to the underlying task. When this is True use `aio_mock_run`;
    otherwise use `mock_run`.

    :returns: True if the task is an async function, False otherwise.
    """
    wrapped_task = self._task
    return wrapped_task.is_async_function
1451
+
1452
def get_run_ref(self, run_id: str) -> TaskRunRef[TWorkflowInput, R]:
    """
    Build a typed reference to an existing run of this task.

    :param run_id: ID of the run to reference.
    :returns: A `TaskRunRef` wrapping the runs client's reference for `run_id`.
    """
    runs_client = self._workflow.client._client.runs
    workflow_run_ref = runs_client.get_run_ref(run_id)
    return TaskRunRef[TWorkflowInput, R](self, workflow_run_ref)
1461
+
1462
async def aio_get_result(self, run_id: str) -> R:
    """
    Await and return the result of an existing run of this task.

    :param run_id: ID of the run whose result is wanted.
    :returns: The value produced by the run reference's `aio_result()`.
    """
    return await self.get_run_ref(run_id).aio_result()
1472
+
1473
def get_result(self, run_id: str) -> R:
    """
    Block until an existing run of this task finishes and return its result.

    :param run_id: ID of the run whose result is wanted.
    :returns: The value produced by the run reference's `result()`.
    """
    return self.get_run_ref(run_id).result()