hatchet-sdk 1.15.3__py3-none-any.whl → 1.16.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of hatchet-sdk might be problematic.
- hatchet_sdk/__init__.py +4 -0
- hatchet_sdk/client.py +2 -0
- hatchet_sdk/clients/admin.py +3 -1
- hatchet_sdk/clients/dispatcher/action_listener.py +13 -13
- hatchet_sdk/clients/event_ts.py +1 -1
- hatchet_sdk/clients/listeners/pooled_listener.py +4 -4
- hatchet_sdk/clients/rest/__init__.py +8 -0
- hatchet_sdk/clients/rest/api/__init__.py +1 -0
- hatchet_sdk/clients/rest/api/cel_api.py +334 -0
- hatchet_sdk/clients/rest/api/task_api.py +12 -10
- hatchet_sdk/clients/rest/models/__init__.py +7 -0
- hatchet_sdk/clients/rest/models/v1_cancelled_tasks.py +87 -0
- hatchet_sdk/clients/rest/models/v1_cel_debug_error_response.py +93 -0
- hatchet_sdk/clients/rest/models/v1_cel_debug_request.py +108 -0
- hatchet_sdk/clients/rest/models/v1_cel_debug_response.py +100 -0
- hatchet_sdk/clients/rest/models/v1_cel_debug_response_status.py +37 -0
- hatchet_sdk/clients/rest/models/v1_cel_debug_success_response.py +102 -0
- hatchet_sdk/clients/rest/models/v1_replayed_tasks.py +87 -0
- hatchet_sdk/clients/rest/tenacity_utils.py +1 -1
- hatchet_sdk/clients/v1/api_client.py +3 -3
- hatchet_sdk/context/context.py +6 -6
- hatchet_sdk/features/cel.py +99 -0
- hatchet_sdk/features/runs.py +21 -1
- hatchet_sdk/hatchet.py +8 -0
- hatchet_sdk/opentelemetry/instrumentor.py +3 -3
- hatchet_sdk/runnables/contextvars.py +4 -0
- hatchet_sdk/runnables/task.py +133 -0
- hatchet_sdk/runnables/workflow.py +218 -11
- hatchet_sdk/worker/action_listener_process.py +11 -11
- hatchet_sdk/worker/runner/runner.py +39 -21
- hatchet_sdk/worker/runner/utils/capture_logs.py +30 -15
- hatchet_sdk/worker/worker.py +11 -14
- {hatchet_sdk-1.15.3.dist-info → hatchet_sdk-1.16.1.dist-info}/METADATA +1 -1
- {hatchet_sdk-1.15.3.dist-info → hatchet_sdk-1.16.1.dist-info}/RECORD +36 -27
- {hatchet_sdk-1.15.3.dist-info → hatchet_sdk-1.16.1.dist-info}/WHEEL +0 -0
- {hatchet_sdk-1.15.3.dist-info → hatchet_sdk-1.16.1.dist-info}/entry_points.txt +0 -0
hatchet_sdk/features/cel.py
ADDED

@@ -0,0 +1,99 @@
+import asyncio
+from typing import Literal
+
+from pydantic import BaseModel, Field
+
+from hatchet_sdk.clients.rest.api.cel_api import CELApi
+from hatchet_sdk.clients.rest.api_client import ApiClient
+from hatchet_sdk.clients.rest.models.v1_cel_debug_request import V1CELDebugRequest
+from hatchet_sdk.clients.rest.models.v1_cel_debug_response_status import (
+    V1CELDebugResponseStatus,
+)
+from hatchet_sdk.clients.v1.api_client import BaseRestClient, retry
+from hatchet_sdk.utils.typing import JSONSerializableMapping
+
+
+class CELSuccess(BaseModel):
+    status: Literal["success"] = "success"
+    output: bool
+
+
+class CELFailure(BaseModel):
+    status: Literal["failure"] = "failure"
+    error: str
+
+
+class CELEvaluationResult(BaseModel):
+    result: CELSuccess | CELFailure = Field(discriminator="status")
+
+
+class CELClient(BaseRestClient):
+    """
+    The CEL client is a client for debugging CEL expressions within Hatchet.
+    """
+
+    def _ca(self, client: ApiClient) -> CELApi:
+        return CELApi(client)
+
+    @retry
+    def debug(
+        self,
+        expression: str,
+        input: JSONSerializableMapping,
+        additional_metadata: JSONSerializableMapping | None = None,
+        filter_payload: JSONSerializableMapping | None = None,
+    ) -> CELEvaluationResult:
+        """
+        Debug a CEL expression with the provided input, filter payload, and optional metadata. Useful for testing and validating CEL expressions and debugging issues in production.
+
+        :param expression: The CEL expression to debug.
+        :param input: The input, which simulates the workflow run input.
+        :param additional_metadata: Additional metadata, which simulates metadata that could be sent with an event or a workflow run.
+        :param filter_payload: The filter payload, which simulates a payload set on a previously created filter.
+
+        :raises ValueError: If no response is received from the CEL debug API.
+
+        :return: A `CELEvaluationResult` containing the result of the debug operation.
+        """
+        request = V1CELDebugRequest(
+            expression=expression,
+            input=input,
+            additionalMetadata=additional_metadata,
+            filterPayload=filter_payload,
+        )
+        with self.client() as client:
+            result = self._ca(client).v1_cel_debug(
+                tenant=self.client_config.tenant_id, v1_cel_debug_request=request
+            )
+
+        if result.status == V1CELDebugResponseStatus.ERROR:
+            if result.error is None:
+                raise ValueError("No error message received from CEL debug API.")
+
+            return CELEvaluationResult(result=CELFailure(error=result.error))
+
+        if result.output is None:
+            raise ValueError("No output received from CEL debug API.")
+
+        return CELEvaluationResult(result=CELSuccess(output=result.output))
+
+    async def aio_debug(
+        self,
+        expression: str,
+        input: JSONSerializableMapping,
+        additional_metadata: JSONSerializableMapping | None = None,
+        filter_payload: JSONSerializableMapping | None = None,
+    ) -> CELEvaluationResult:
+        """
+        Debug a CEL expression with the provided input, filter payload, and optional metadata. Useful for testing and validating CEL expressions and debugging issues in production.
+
+        :param expression: The CEL expression to debug.
+        :param input: The input, which simulates the workflow run input.
+        :param additional_metadata: Additional metadata, which simulates metadata that could be sent with an event or a workflow run.
+        :param filter_payload: The filter payload, which simulates a payload set on a previously created filter.
+
+        :return: A `CELEvaluationResult` containing the result of the debug operation.
+        """
+        return await asyncio.to_thread(
+            self.debug, expression, input, additional_metadata, filter_payload
+        )
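A minimal usage sketch for the new CEL debug client, assuming a configured `Hatchet` client (the `cel` accessor is wired up in `hatchet.py` further down in this diff); the expression, input, and variable names are illustrative:

from hatchet_sdk import Hatchet

hatchet = Hatchet()

# Evaluate a sample CEL expression against a simulated workflow input.
result = hatchet.cel.debug(
    expression="input.user_id == 'abc-123'",
    input={"user_id": "abc-123"},
)

# The discriminated union on `status` makes both outcomes explicit.
if result.result.status == "success":
    print("expression evaluated to:", result.result.output)
else:
    print("evaluation failed:", result.result.error)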
hatchet_sdk/features/runs.py
CHANGED

@@ -119,6 +119,26 @@ class RunsClient(BaseRestClient):
     def _ta(self, client: ApiClient) -> TaskApi:
         return TaskApi(client)
 
+    @retry
+    def get_task_run(self, task_run_id: str) -> V1TaskSummary:
+        """
+        Get task run details for a given task run ID.
+
+        :param task_run_id: The ID of the task run to retrieve details for.
+        :return: Task run details for the specified task run ID.
+        """
+        with self.client() as client:
+            return self._ta(client).v1_task_get(task_run_id)
+
+    async def aio_get_task_run(self, task_run_id: str) -> V1TaskSummary:
+        """
+        Get task run details for a given task run ID.
+
+        :param task_run_id: The ID of the task run to retrieve details for.
+        :return: Task run details for the specified task run ID.
+        """
+        return await asyncio.to_thread(self.get_task_run, task_run_id)
+
     @retry
     def get(self, workflow_run_id: str) -> V1WorkflowRunDetails:
         """

@@ -148,7 +168,7 @@ class RunsClient(BaseRestClient):
         :return: The task status
         """
         with self.client() as client:
-            return self._wra(client).v1_workflow_run_get_status(
+            return self._wra(client).v1_workflow_run_get_status(workflow_run_id)
 
     async def aio_get_status(self, workflow_run_id: str) -> V1TaskStatus:
         """
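A short sketch of the new task-run lookup, assuming the `RunsClient` is exposed as `hatchet.runs` on a configured client (attribute name assumed) and using a placeholder run ID:

# Sync lookup of a single task run's details.
task = hatchet.runs.get_task_run("<task-run-id>")
print(task.status)

# Async variant (inside an async function); per the diff it simply
# delegates to the sync call via asyncio.to_thread.
task = await hatchet.runs.aio_get_task_run("<task-run-id>")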
hatchet_sdk/hatchet.py
CHANGED

@@ -12,6 +12,7 @@ from hatchet_sdk.clients.events import EventClient
 from hatchet_sdk.clients.listeners.run_event_listener import RunEventListenerClient
 from hatchet_sdk.clients.rest.models.tenant_version import TenantVersion
 from hatchet_sdk.config import ClientConfig
+from hatchet_sdk.features.cel import CELClient
 from hatchet_sdk.features.cron import CronClient
 from hatchet_sdk.features.filters import FiltersClient
 from hatchet_sdk.features.logs import LogsClient

@@ -66,6 +67,13 @@ class Hatchet:
                 "🚨⚠️‼️ YOU ARE USING A V0 ENGINE WITH A V1 SDK, WHICH IS NOT SUPPORTED. PLEASE UPGRADE YOUR ENGINE TO V1.🚨⚠️‼️"
             )
 
+    @property
+    def cel(self) -> CELClient:
+        """
+        The CEL client is a client for interacting with Hatchet's CEL API.
+        """
+        return self._client.cel
+
     @property
     def cron(self) -> CronClient:
         """
hatchet_sdk/opentelemetry/instrumentor.py
CHANGED

@@ -64,7 +64,7 @@ OTEL_TRACEPARENT_KEY = "traceparent"
 
 def create_traceparent() -> str | None:
     logger.warning(
-        "
+        "as of SDK version 1.11.0, you no longer need to call `create_traceparent` manually. The traceparent will be automatically created by the instrumentor and injected into the metadata of actions and events when appropriate. This method will be removed in a future version.",
     )
     return _create_traceparent()

@@ -91,7 +91,7 @@ def parse_carrier_from_metadata(
     metadata: JSONSerializableMapping | None,
 ) -> Context | None:
     logger.warning(
-        "
+        "as of SDK version 1.11.0, you no longer need to call `parse_carrier_from_metadata` manually. This method will be removed in a future version.",
     )
 
     return _parse_carrier_from_metadata(metadata)

@@ -133,7 +133,7 @@ def inject_traceparent_into_metadata(
     metadata: dict[str, str], traceparent: str | None = None
 ) -> dict[str, str]:
     logger.warning(
-        "
+        "as of SDK version 1.11.0, you no longer need to call `inject_traceparent_into_metadata` manually. The traceparent will automatically be injected by the instrumentor. This method will be removed in a future version.",
     )
 
     return _inject_traceparent_into_metadata(metadata, traceparent)
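A hedged before/after sketch of what these updated deprecation warnings imply for callers; the metadata values are illustrative:

# Before (deprecated): manually creating and injecting a traceparent.
traceparent = create_traceparent()
metadata = inject_traceparent_into_metadata({"source": "api"}, traceparent)

# After (SDK >= 1.11.0): pass plain metadata; the instrumentor creates and
# injects the traceparent automatically when appropriate.
metadata = {"source": "api"}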
hatchet_sdk/runnables/contextvars.py
CHANGED

@@ -4,6 +4,7 @@ from collections import Counter
 from contextvars import ContextVar
 
 from hatchet_sdk.runnables.action import ActionKey
+from hatchet_sdk.utils.typing import JSONSerializableMapping
 
 ctx_workflow_run_id: ContextVar[str | None] = ContextVar(
     "ctx_workflow_run_id", default=None

@@ -13,6 +14,9 @@ ctx_action_key: ContextVar[ActionKey | None] = ContextVar(
 )
 ctx_step_run_id: ContextVar[str | None] = ContextVar("ctx_step_run_id", default=None)
 ctx_worker_id: ContextVar[str | None] = ContextVar("ctx_worker_id", default=None)
+ctx_additional_metadata: ContextVar[JSONSerializableMapping | None] = ContextVar(
+    "ctx_additional_metadata", default=None
+)
 
 workflow_spawn_indices = Counter[ActionKey]()
 spawn_index_lock = asyncio.Lock()
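A one-line sketch of reading the new context variable; the name comes from the diff, but the call site is hypothetical:

# Returns the current run's additional metadata, or None outside a run context.
meta = ctx_additional_metadata.get()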
hatchet_sdk/runnables/task.py
CHANGED

@@ -11,6 +11,7 @@ from hatchet_sdk.conditions import (
     flatten_conditions,
 )
 from hatchet_sdk.context.context import Context, DurableContext
+from hatchet_sdk.context.worker_context import WorkerContext
 from hatchet_sdk.contracts.v1.shared.condition_pb2 import TaskConditions
 from hatchet_sdk.contracts.v1.workflows_pb2 import (
     CreateTaskOpts,

@@ -19,6 +20,7 @@ from hatchet_sdk.contracts.v1.workflows_pb2 import (
 )
 from hatchet_sdk.runnables.types import (
     ConcurrencyExpression,
+    EmptyModel,
     R,
     StepType,
     TWorkflowInput,

@@ -30,9 +32,11 @@ from hatchet_sdk.utils.timedelta_to_expression import Duration, timedelta_to_exp
 from hatchet_sdk.utils.typing import (
     AwaitableLike,
     CoroutineLike,
+    JSONSerializableMapping,
     TaskIOValidator,
     is_basemodel_subclass,
 )
+from hatchet_sdk.worker.runner.utils.capture_logs import AsyncLogSender
 
 if TYPE_CHECKING:
     from hatchet_sdk.runnables.workflow import Workflow

@@ -186,3 +190,132 @@ class Task(Generic[TWorkflowInput, R]):
             sleep_conditions=sleep_conditions,
             user_event_conditions=user_events,
         )
+
+    def _create_mock_context(
+        self,
+        input: TWorkflowInput | None,
+        additional_metadata: JSONSerializableMapping | None = None,
+        parent_outputs: dict[str, JSONSerializableMapping] | None = None,
+        retry_count: int = 0,
+        lifespan_context: Any = None,
+    ) -> Context | DurableContext:
+        from hatchet_sdk.runnables.action import Action, ActionPayload, ActionType
+
+        additional_metadata = additional_metadata or {}
+        parent_outputs = parent_outputs or {}
+
+        if input is None:
+            input = cast(TWorkflowInput, EmptyModel())
+
+        action_payload = ActionPayload(input=input.model_dump(), parents=parent_outputs)
+
+        action = Action(
+            tenant_id=self.workflow.client.config.tenant_id,
+            worker_id="mock-worker-id",
+            workflow_run_id="mock-workflow-run-id",
+            get_group_key_run_id="mock-get-group-key-run-id",
+            job_id="mock-job-id",
+            job_name="mock-job-name",
+            job_run_id="mock-job-run-id",
+            step_id="mock-step-id",
+            step_run_id="mock-step-run-id",
+            action_id="mock:action",
+            action_payload=action_payload,
+            action_type=ActionType.START_STEP_RUN,
+            retry_count=retry_count,
+            additional_metadata=additional_metadata,
+            child_workflow_index=None,
+            child_workflow_key=None,
+            parent_workflow_run_id=None,
+            priority=1,
+            workflow_version_id="mock-workflow-version-id",
+            workflow_id="mock-workflow-id",
+        )
+
+        constructor = DurableContext if self.is_durable else Context
+
+        return constructor(
+            action=action,
+            dispatcher_client=self.workflow.client._client.dispatcher,
+            admin_client=self.workflow.client._client.admin,
+            event_client=self.workflow.client._client.event,
+            durable_event_listener=None,
+            worker=WorkerContext(
+                labels={}, client=self.workflow.client._client.dispatcher
+            ),
+            runs_client=self.workflow.client._client.runs,
+            lifespan_context=lifespan_context,
+            log_sender=AsyncLogSender(self.workflow.client._client.event),
+        )
+
+    def mock_run(
+        self,
+        input: TWorkflowInput | None = None,
+        additional_metadata: JSONSerializableMapping | None = None,
+        parent_outputs: dict[str, JSONSerializableMapping] | None = None,
+        retry_count: int = 0,
+        lifespan: Any = None,
+    ) -> R:
+        """
+        Mimic the execution of a task. This method is intended to be used to unit test
+        tasks without needing to interact with the Hatchet engine. Use `mock_run` for sync
+        tasks and `aio_mock_run` for async tasks.
+
+        :param input: The input to the task.
+        :param additional_metadata: Additional metadata to attach to the task.
+        :param parent_outputs: Outputs from parent tasks, if any. This is useful for mimicking DAG functionality. For instance, if you have a task `step_2` that has a `parent` which is `step_1`, you can pass `parent_outputs={"step_1": {"result": "Hello, world!"}}` to `step_2.mock_run()` to be able to access `ctx.task_output(step_1)` in `step_2`.
+        :param retry_count: The number of times the task has been retried.
+        :param lifespan: The lifespan to be used in the task, which is useful if one was set on the worker. This will allow you to access `ctx.lifespan` inside of your task.
+
+        :return: The output of the task.
+        :raises TypeError: If the task is an async function and `mock_run` is called, or if the task is a sync function and `aio_mock_run` is called.
+        """
+
+        if self.is_async_function:
+            raise TypeError(
+                f"{self.name} is not a sync function. Use `aio_mock_run` instead."
+            )
+
+        ctx = self._create_mock_context(
+            input, additional_metadata, parent_outputs, retry_count, lifespan
+        )
+
+        return self.call(ctx)
+
+    async def aio_mock_run(
+        self,
+        input: TWorkflowInput | None = None,
+        additional_metadata: JSONSerializableMapping | None = None,
+        parent_outputs: dict[str, JSONSerializableMapping] | None = None,
+        retry_count: int = 0,
+        lifespan: Any = None,
+    ) -> R:
+        """
+        Mimic the execution of a task. This method is intended to be used to unit test
+        tasks without needing to interact with the Hatchet engine. Use `mock_run` for sync
+        tasks and `aio_mock_run` for async tasks.
+
+        :param input: The input to the task.
+        :param additional_metadata: Additional metadata to attach to the task.
+        :param parent_outputs: Outputs from parent tasks, if any. This is useful for mimicking DAG functionality. For instance, if you have a task `step_2` that has a `parent` which is `step_1`, you can pass `parent_outputs={"step_1": {"result": "Hello, world!"}}` to `step_2.mock_run()` to be able to access `ctx.task_output(step_1)` in `step_2`.
+        :param retry_count: The number of times the task has been retried.
+        :param lifespan: The lifespan to be used in the task, which is useful if one was set on the worker. This will allow you to access `ctx.lifespan` inside of your task.
+
+        :return: The output of the task.
+        :raises TypeError: If the task is an async function and `mock_run` is called, or if the task is a sync function and `aio_mock_run` is called.
+        """
+
+        if not self.is_async_function:
+            raise TypeError(
+                f"{self.name} is not an async function. Use `mock_run` instead."
+            )
+
+        ctx = self._create_mock_context(
+            input,
+            additional_metadata,
+            parent_outputs,
+            retry_count,
+            lifespan,
+        )
+
+        return await self.aio_call(ctx)
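A hedged unit-test sketch for the new mock-run helpers; `step_1`, `step_2`, `my_async_task`, and the input model are hypothetical tasks and types defined elsewhere with this SDK:

from pydantic import BaseModel

class MyInput(BaseModel):
    name: str

# Sync task: runs the task body directly against a mocked Context.
result = step_2.mock_run(
    input=MyInput(name="world"),
    # Simulate the DAG parent so ctx.task_output(step_1) resolves inside step_2.
    parent_outputs={"step_1": {"result": "Hello, world!"}},
)

# Async task, e.g. inside a pytest-asyncio test.
result = await my_async_task.aio_mock_run(input=MyInput(name="world"))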
hatchet_sdk/runnables/workflow.py
CHANGED

@@ -2,7 +2,16 @@ import asyncio
 from collections.abc import Callable
 from datetime import datetime, timedelta
 from functools import cached_property
-from typing import
+from typing import (
+    TYPE_CHECKING,
+    Any,
+    Generic,
+    Literal,
+    TypeVar,
+    cast,
+    get_type_hints,
+    overload,
+)
 
 from google.protobuf import timestamp_pb2
 from pydantic import BaseModel, model_validator
@@ -651,39 +660,83 @@ class Workflow(BaseWorkflow[TWorkflowInput]):
 
         return await ref.aio_result()
 
+    def _get_result(
+        self, ref: WorkflowRunRef, return_exceptions: bool
+    ) -> dict[str, Any] | BaseException:
+        try:
+            return ref.result()
+        except Exception as e:
+            if return_exceptions:
+                return e
+            raise e
+
+    @overload
+    def run_many(
+        self,
+        workflows: list[WorkflowRunTriggerConfig],
+        return_exceptions: Literal[True],
+    ) -> list[dict[str, Any] | BaseException]: ...
+
+    @overload
     def run_many(
         self,
         workflows: list[WorkflowRunTriggerConfig],
-
+        return_exceptions: Literal[False] = False,
+    ) -> list[dict[str, Any]]: ...
+
+    def run_many(
+        self,
+        workflows: list[WorkflowRunTriggerConfig],
+        return_exceptions: bool = False,
+    ) -> list[dict[str, Any]] | list[dict[str, Any] | BaseException]:
         """
         Run a workflow in bulk and wait for all runs to complete.
         This method triggers multiple workflow runs, blocks until all of them complete, and returns the final results.
 
         :param workflows: A list of `WorkflowRunTriggerConfig` objects, each representing a workflow run to be triggered.
+        :param return_exceptions: If `True`, exceptions will be returned as part of the results instead of raising them.
         :returns: A list of results for each workflow run.
         """
         refs = self.client._client.admin.run_workflows(
             workflows=workflows,
         )
 
-        return [
+        return [self._get_result(ref, return_exceptions) for ref in refs]
 
+    @overload
     async def aio_run_many(
         self,
         workflows: list[WorkflowRunTriggerConfig],
-
+        return_exceptions: Literal[True],
+    ) -> list[dict[str, Any] | BaseException]: ...
+
+    @overload
+    async def aio_run_many(
+        self,
+        workflows: list[WorkflowRunTriggerConfig],
+        return_exceptions: Literal[False] = False,
+    ) -> list[dict[str, Any]]: ...
+
+    async def aio_run_many(
+        self,
+        workflows: list[WorkflowRunTriggerConfig],
+        return_exceptions: bool = False,
+    ) -> list[dict[str, Any]] | list[dict[str, Any] | BaseException]:
         """
         Run a workflow in bulk and wait for all runs to complete.
         This method triggers multiple workflow runs, blocks until all of them complete, and returns the final results.
 
         :param workflows: A list of `WorkflowRunTriggerConfig` objects, each representing a workflow run to be triggered.
+        :param return_exceptions: If `True`, exceptions will be returned as part of the results instead of raising them.
         :returns: A list of results for each workflow run.
         """
         refs = await self.client._client.admin.aio_run_workflows(
             workflows=workflows,
         )
 
-        return await asyncio.gather(
+        return await asyncio.gather(
+            *[ref.aio_result() for ref in refs], return_exceptions=return_exceptions
+        )
 
     def run_many_no_wait(
         self,
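A hedged sketch of the new `return_exceptions` flag on bulk runs; `my_workflow`, its `MyInput` model, and `create_bulk_run_item` (assumed from the SDK's bulk-run API) are placeholders:

configs = [
    my_workflow.create_bulk_run_item(input=MyInput(name=str(i))) for i in range(3)
]

# With return_exceptions=True, failed runs come back as exception objects
# in the result list instead of aborting the whole batch.
results = my_workflow.run_many(configs, return_exceptions=True)
for r in results:
    if isinstance(r, BaseException):
        print("run failed:", r)
    else:
        print("run succeeded:", r)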
@@ -946,7 +999,7 @@ class Workflow(BaseWorkflow[TWorkflowInput]):
 
         :param backoff_max_seconds: The maximum number of seconds to allow retries with exponential backoff to continue.
 
-        :param concurrency: A list of concurrency expressions for the on-
+        :param concurrency: A list of concurrency expressions for the on-failure task.
 
         :returns: A decorator which creates a `Task` object.
         """
@@ -1137,7 +1190,18 @@ class Standalone(BaseWorkflow[TWorkflowInput], Generic[TWorkflowInput, R]):
 
         self.config = self._workflow.config
 
-
+    @overload
+    def _extract_result(self, result: dict[str, Any]) -> R: ...
+
+    @overload
+    def _extract_result(self, result: BaseException) -> BaseException: ...
+
+    def _extract_result(
+        self, result: dict[str, Any] | BaseException
+    ) -> R | BaseException:
+        if isinstance(result, BaseException):
+            return result
+
         output = result.get(self._task.name)
 
         if not self._output_validator:
@@ -1217,30 +1281,72 @@ class Standalone(BaseWorkflow[TWorkflowInput], Generic[TWorkflowInput, R]):
 
         return TaskRunRef[TWorkflowInput, R](self, ref)
 
+    @overload
+    def run_many(
+        self,
+        workflows: list[WorkflowRunTriggerConfig],
+        return_exceptions: Literal[True],
+    ) -> list[R | BaseException]: ...
+
+    @overload
+    def run_many(
+        self,
+        workflows: list[WorkflowRunTriggerConfig],
+        return_exceptions: Literal[False] = False,
+    ) -> list[R]: ...
+
+    def run_many(
+        self, workflows: list[WorkflowRunTriggerConfig], return_exceptions: bool = False
+    ) -> list[R] | list[R | BaseException]:
-
         """
         Run a workflow in bulk and wait for all runs to complete.
         This method triggers multiple workflow runs, blocks until all of them complete, and returns the final results.
 
         :param workflows: A list of `WorkflowRunTriggerConfig` objects, each representing a workflow run to be triggered.
+        :param return_exceptions: If `True`, exceptions will be returned as part of the results instead of raising them.
         :returns: A list of results for each workflow run.
         """
         return [
             self._extract_result(result)
-            for result in self._workflow.run_many(
+            for result in self._workflow.run_many(
+                workflows,
+                ## hack: typing needs literal
+                True if return_exceptions else False,  # noqa: SIM210
+            )
         ]
 
-
+    @overload
+    async def aio_run_many(
+        self,
+        workflows: list[WorkflowRunTriggerConfig],
+        return_exceptions: Literal[True],
+    ) -> list[R | BaseException]: ...
+
+    @overload
+    async def aio_run_many(
+        self,
+        workflows: list[WorkflowRunTriggerConfig],
+        return_exceptions: Literal[False] = False,
+    ) -> list[R]: ...
+
+    async def aio_run_many(
+        self, workflows: list[WorkflowRunTriggerConfig], return_exceptions: bool = False
+    ) -> list[R] | list[R | BaseException]:
         """
         Run a workflow in bulk and wait for all runs to complete.
         This method triggers multiple workflow runs, blocks until all of them complete, and returns the final results.
 
         :param workflows: A list of `WorkflowRunTriggerConfig` objects, each representing a workflow run to be triggered.
+        :param return_exceptions: If `True`, exceptions will be returned as part of the results instead of raising them.
         :returns: A list of results for each workflow run.
         """
         return [
             self._extract_result(result)
-            for result in await self._workflow.aio_run_many(
+            for result in await self._workflow.aio_run_many(
+                workflows,
+                ## hack: typing needs literal
+                True if return_exceptions else False,  # noqa: SIM210
+            )
         ]
 
     def run_many_no_wait(
@@ -1273,3 +1379,104 @@ class Standalone(BaseWorkflow[TWorkflowInput], Generic[TWorkflowInput, R]):
         refs = await self._workflow.aio_run_many_no_wait(workflows)
 
         return [TaskRunRef[TWorkflowInput, R](self, ref) for ref in refs]
+
+    def mock_run(
+        self,
+        input: TWorkflowInput | None = None,
+        additional_metadata: JSONSerializableMapping | None = None,
+        parent_outputs: dict[str, JSONSerializableMapping] | None = None,
+        retry_count: int = 0,
+        lifespan: Any = None,
+    ) -> R:
+        """
+        Mimic the execution of a task. This method is intended to be used to unit test
+        tasks without needing to interact with the Hatchet engine. Use `mock_run` for sync
+        tasks and `aio_mock_run` for async tasks.
+
+        :param input: The input to the task.
+        :param additional_metadata: Additional metadata to attach to the task.
+        :param parent_outputs: Outputs from parent tasks, if any. This is useful for mimicking DAG functionality. For instance, if you have a task `step_2` that has a `parent` which is `step_1`, you can pass `parent_outputs={"step_1": {"result": "Hello, world!"}}` to `step_2.mock_run()` to be able to access `ctx.task_output(step_1)` in `step_2`.
+        :param retry_count: The number of times the task has been retried.
+        :param lifespan: The lifespan to be used in the task, which is useful if one was set on the worker. This will allow you to access `ctx.lifespan` inside of your task.
+
+        :return: The output of the task.
+        """
+
+        return self._task.mock_run(
+            input=input,
+            additional_metadata=additional_metadata,
+            parent_outputs=parent_outputs,
+            retry_count=retry_count,
+            lifespan=lifespan,
+        )
+
+    async def aio_mock_run(
+        self,
+        input: TWorkflowInput | None = None,
+        additional_metadata: JSONSerializableMapping | None = None,
+        parent_outputs: dict[str, JSONSerializableMapping] | None = None,
+        retry_count: int = 0,
+        lifespan: Any = None,
+    ) -> R:
+        """
+        Mimic the execution of a task. This method is intended to be used to unit test
+        tasks without needing to interact with the Hatchet engine. Use `mock_run` for sync
+        tasks and `aio_mock_run` for async tasks.
+
+        :param input: The input to the task.
+        :param additional_metadata: Additional metadata to attach to the task.
+        :param parent_outputs: Outputs from parent tasks, if any. This is useful for mimicking DAG functionality. For instance, if you have a task `step_2` that has a `parent` which is `step_1`, you can pass `parent_outputs={"step_1": {"result": "Hello, world!"}}` to `step_2.mock_run()` to be able to access `ctx.task_output(step_1)` in `step_2`.
+        :param retry_count: The number of times the task has been retried.
+        :param lifespan: The lifespan to be used in the task, which is useful if one was set on the worker. This will allow you to access `ctx.lifespan` inside of your task.
+
+        :return: The output of the task.
+        """
+
+        return await self._task.aio_mock_run(
+            input=input,
+            additional_metadata=additional_metadata,
+            parent_outputs=parent_outputs,
+            retry_count=retry_count,
+            lifespan=lifespan,
+        )
+
+    @property
+    def is_async_function(self) -> bool:
+        """
+        Check if the task is an async function.
+
+        :returns: True if the task is an async function, False otherwise.
+        """
+        return self._task.is_async_function
+
+    def get_run_ref(self, run_id: str) -> TaskRunRef[TWorkflowInput, R]:
+        """
+        Get a reference to a task run by its run ID.
+
+        :param run_id: The ID of the run to get the reference for.
+        :returns: A `TaskRunRef` object representing the reference to the task run.
+        """
+        wrr = self._workflow.client._client.runs.get_run_ref(run_id)
+        return TaskRunRef[TWorkflowInput, R](self, wrr)
+
+    async def aio_get_result(self, run_id: str) -> R:
+        """
+        Get the result of a task run by its run ID.
+
+        :param run_id: The ID of the run to get the result for.
+        :returns: The result of the task run.
+        """
+        run_ref = self.get_run_ref(run_id)
+
+        return await run_ref.aio_result()
+
+    def get_result(self, run_id: str) -> R:
+        """
+        Get the result of a task run by its run ID.
+
+        :param run_id: The ID of the run to get the result for.
+        :returns: The result of the task run.
+        """
+        run_ref = self.get_run_ref(run_id)
+
+        return run_ref.result()
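A short sketch of resolving a standalone task's typed result later by run ID; `my_task` and the stored run ID are placeholders:

# Re-attach to a previously triggered run and block for its typed output.
output = my_task.get_result("<workflow-run-id>")

# Or, from async code:
output = await my_task.aio_get_result("<workflow-run-id>")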