hatchet-sdk 1.0.0__py3-none-any.whl → 1.0.0a1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release: this version of hatchet-sdk has been flagged as possibly problematic.
- hatchet_sdk/__init__.py +27 -16
- hatchet_sdk/client.py +13 -63
- hatchet_sdk/clients/admin.py +203 -124
- hatchet_sdk/clients/dispatcher/action_listener.py +42 -42
- hatchet_sdk/clients/dispatcher/dispatcher.py +18 -16
- hatchet_sdk/clients/durable_event_listener.py +327 -0
- hatchet_sdk/clients/rest/__init__.py +12 -1
- hatchet_sdk/clients/rest/api/log_api.py +258 -0
- hatchet_sdk/clients/rest/api/task_api.py +32 -6
- hatchet_sdk/clients/rest/api/workflow_runs_api.py +626 -0
- hatchet_sdk/clients/rest/models/__init__.py +12 -1
- hatchet_sdk/clients/rest/models/v1_log_line.py +94 -0
- hatchet_sdk/clients/rest/models/v1_log_line_level.py +39 -0
- hatchet_sdk/clients/rest/models/v1_log_line_list.py +110 -0
- hatchet_sdk/clients/rest/models/v1_task_summary.py +80 -64
- hatchet_sdk/clients/rest/models/v1_trigger_workflow_run_request.py +95 -0
- hatchet_sdk/clients/rest/models/v1_workflow_run_display_name.py +98 -0
- hatchet_sdk/clients/rest/models/v1_workflow_run_display_name_list.py +114 -0
- hatchet_sdk/clients/rest/models/workflow_run_shape_item_for_workflow_run_details.py +9 -4
- hatchet_sdk/clients/rest_client.py +21 -0
- hatchet_sdk/clients/run_event_listener.py +0 -1
- hatchet_sdk/context/context.py +85 -147
- hatchet_sdk/contracts/dispatcher_pb2_grpc.py +1 -1
- hatchet_sdk/contracts/events_pb2.py +2 -2
- hatchet_sdk/contracts/events_pb2_grpc.py +1 -1
- hatchet_sdk/contracts/v1/dispatcher_pb2.py +36 -0
- hatchet_sdk/contracts/v1/dispatcher_pb2.pyi +38 -0
- hatchet_sdk/contracts/v1/dispatcher_pb2_grpc.py +145 -0
- hatchet_sdk/contracts/v1/shared/condition_pb2.py +39 -0
- hatchet_sdk/contracts/v1/shared/condition_pb2.pyi +72 -0
- hatchet_sdk/contracts/v1/shared/condition_pb2_grpc.py +29 -0
- hatchet_sdk/contracts/v1/workflows_pb2.py +67 -0
- hatchet_sdk/contracts/v1/workflows_pb2.pyi +228 -0
- hatchet_sdk/contracts/v1/workflows_pb2_grpc.py +234 -0
- hatchet_sdk/contracts/workflows_pb2_grpc.py +1 -1
- hatchet_sdk/features/cron.py +3 -3
- hatchet_sdk/features/scheduled.py +2 -2
- hatchet_sdk/hatchet.py +427 -151
- hatchet_sdk/opentelemetry/instrumentor.py +8 -13
- hatchet_sdk/rate_limit.py +33 -39
- hatchet_sdk/runnables/contextvars.py +12 -0
- hatchet_sdk/runnables/standalone.py +194 -0
- hatchet_sdk/runnables/task.py +144 -0
- hatchet_sdk/runnables/types.py +138 -0
- hatchet_sdk/runnables/workflow.py +764 -0
- hatchet_sdk/utils/aio_utils.py +0 -79
- hatchet_sdk/utils/proto_enums.py +0 -7
- hatchet_sdk/utils/timedelta_to_expression.py +23 -0
- hatchet_sdk/utils/typing.py +2 -2
- hatchet_sdk/v0/clients/rest_client.py +9 -0
- hatchet_sdk/v0/worker/action_listener_process.py +18 -2
- hatchet_sdk/waits.py +120 -0
- hatchet_sdk/worker/action_listener_process.py +64 -30
- hatchet_sdk/worker/runner/run_loop_manager.py +35 -25
- hatchet_sdk/worker/runner/runner.py +72 -49
- hatchet_sdk/worker/runner/utils/capture_logs.py +3 -11
- hatchet_sdk/worker/worker.py +155 -118
- hatchet_sdk/workflow_run.py +4 -5
- {hatchet_sdk-1.0.0.dist-info → hatchet_sdk-1.0.0a1.dist-info}/METADATA +1 -2
- {hatchet_sdk-1.0.0.dist-info → hatchet_sdk-1.0.0a1.dist-info}/RECORD +62 -42
- {hatchet_sdk-1.0.0.dist-info → hatchet_sdk-1.0.0a1.dist-info}/entry_points.txt +2 -0
- hatchet_sdk/semver.py +0 -30
- hatchet_sdk/worker/runner/utils/error_with_traceback.py +0 -6
- hatchet_sdk/workflow.py +0 -527
- {hatchet_sdk-1.0.0.dist-info → hatchet_sdk-1.0.0a1.dist-info}/WHEEL +0 -0
hatchet_sdk/opentelemetry/instrumentor.py
CHANGED

@@ -30,7 +30,7 @@ import hatchet_sdk
 from hatchet_sdk.clients.admin import (
     AdminClient,
     TriggerWorkflowOptions,
-    WorkflowRunDict,
+    WorkflowRunTriggerConfig,
 )
 from hatchet_sdk.clients.dispatcher.action_listener import Action
 from hatchet_sdk.clients.events import (

@@ -352,14 +352,12 @@ class HatchetInstrumentor(BaseInstrumentor):  # type: ignore[misc]
     def _wrap_run_workflows(
         self,
         wrapped: Callable[
-            [list[WorkflowRunDict], TriggerWorkflowOptions | None], list[WorkflowRunRef]
+            [list[WorkflowRunTriggerConfig]],
+            list[WorkflowRunRef],
         ],
         instance: AdminClient,
-        args: tuple[
-            list[WorkflowRunDict],
-            TriggerWorkflowOptions | None,
-        ],
-        kwargs: dict[str, list[WorkflowRunDict] | TriggerWorkflowOptions | None],
+        args: tuple[list[WorkflowRunTriggerConfig],],
+        kwargs: dict[str, list[WorkflowRunTriggerConfig]],
     ) -> list[WorkflowRunRef]:
         with self._tracer.start_as_current_span(
             "hatchet.run_workflows",

@@ -370,15 +368,12 @@ class HatchetInstrumentor(BaseInstrumentor):  # type: ignore[misc]
     async def _wrap_async_run_workflows(
         self,
         wrapped: Callable[
-            [list[WorkflowRunDict], TriggerWorkflowOptions | None],
+            [list[WorkflowRunTriggerConfig]],
             Coroutine[None, None, list[WorkflowRunRef]],
         ],
         instance: AdminClient,
-        args: tuple[
-            list[WorkflowRunDict],
-            TriggerWorkflowOptions | None,
-        ],
-        kwargs: dict[str, list[WorkflowRunDict] | TriggerWorkflowOptions | None],
+        args: tuple[list[WorkflowRunTriggerConfig],],
+        kwargs: dict[str, list[WorkflowRunTriggerConfig]],
    ) -> list[WorkflowRunRef]:
         with self._tracer.start_as_current_span(
             "hatchet.run_workflows",
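The two hunks above only retype the wrapper signatures: the old (list[WorkflowRunDict], TriggerWorkflowOptions | None) pair collapses into a single list[WorkflowRunTriggerConfig] argument. For readers unfamiliar with the (wrapped, instance, args, kwargs) convention these signatures follow, here is a minimal sketch of how such a wrapper is attached with wrapt; the module path and method name in the last call are illustrative assumptions, not taken from the diff.

from typing import Any, Callable

from opentelemetry import trace
from wrapt import wrap_function_wrapper

tracer = trace.get_tracer(__name__)


def _wrap_run_workflows(
    wrapped: Callable[..., Any],  # the original bound method being wrapped
    instance: Any,                # the AdminClient the method is bound to
    args: tuple[Any, ...],        # positional args, e.g. the trigger-config list
    kwargs: dict[str, Any],
) -> Any:
    # Open a span around the original call, then delegate unchanged.
    with tracer.start_as_current_span("hatchet.run_workflows"):
        return wrapped(*args, **kwargs)


# Hypothetical attach point; the real instrumentor wires this up in its own setup code.
wrap_function_wrapper(
    "hatchet_sdk.clients.admin", "AdminClient.run_workflows", _wrap_run_workflows
)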
hatchet_sdk/rate_limit.py
CHANGED
@@ -1,9 +1,9 @@
-from dataclasses import dataclass
 from enum import Enum
 
 from celpy import CELEvalError, Environment  # type: ignore
+from pydantic import BaseModel, model_validator
 
-from hatchet_sdk.contracts.workflows_pb2 import CreateStepRateLimit
+from hatchet_sdk.contracts.v1.workflows_pb2 import CreateTaskRateLimit
 
 
 def validate_cel_expression(expr: str) -> bool:

@@ -25,8 +25,7 @@ class RateLimitDuration(str, Enum):
     YEAR = "YEAR"
 
 
-@dataclass
-class RateLimit:
+class RateLimit(BaseModel):
     """
     Represents a rate limit configuration for a step in a workflow.
 

@@ -68,46 +67,41 @@ class RateLimit:
     limit: int | str | None = None
     duration: RateLimitDuration = RateLimitDuration.MINUTE
 
-    _req: CreateStepRateLimit | None = None
+    @model_validator(mode="after")
+    def validate_rate_limit(self) -> "RateLimit":
+        if self.dynamic_key and self.static_key:
+            raise ValueError("Cannot have both static key and dynamic key set")
 
-    def __post_init__(self) -> None:
-        # juggle the key and key_expr fields
+        if self.dynamic_key and not validate_cel_expression(self.dynamic_key):
+            raise ValueError(f"Invalid CEL expression: {self.dynamic_key}")
+
+        if not isinstance(self.units, int) and not validate_cel_expression(self.units):
+            raise ValueError(f"Invalid CEL expression: {self.units}")
+
+        if (
+            self.limit
+            and not isinstance(self.limit, int)
+            and not validate_cel_expression(self.limit)
+        ):
+            raise ValueError(f"Invalid CEL expression: {self.limit}")
+
+        if self.dynamic_key and not self.limit:
+            raise ValueError("CEL based keys requires limit to be set")
+
+        return self
+
+    def to_proto(self) -> CreateTaskRateLimit:
         key = self.static_key
         key_expression = self.dynamic_key
 
-        if key is not None and key_expression is not None:
-            raise ValueError(
-                "Cannot have both static key and dynamic key set"
-            )
-
-        if key_expression and not validate_cel_expression(key_expression):
-            raise ValueError(f"Invalid CEL expression: {key_expression}")
-
-        # juggle the units and units_expr fields
-        units = None
-        units_expression = None
-        if isinstance(self.units, int):
-            units = self.units
-        else:
-            if not validate_cel_expression(self.units):
-                raise ValueError(f"Invalid CEL expression: {self.units}")
-            units_expression = self.units
-
-        # juggle the limit and limit_expr fields
-        limit_expression = None
-
-        if self.limit:
-            if isinstance(self.limit, int):
-                limit_expression = f"{self.limit}"
-            else:
-                if not validate_cel_expression(self.limit):
-                    raise ValueError(f"Invalid CEL expression: {self.limit}")
-                limit_expression = self.limit
-
-        if key_expression is not None and limit_expression is None:
-            raise ValueError("CEL based keys requires limit to be set")
+        key = self.static_key or self.dynamic_key
+
+        units = self.units if isinstance(self.units, int) else None
+        units_expression = None if isinstance(self.units, int) else self.units
+
+        limit_expression = None if not self.limit else str(self.limit)
 
-        self._req = CreateStepRateLimit(
+        return CreateTaskRateLimit(
             key=key,
             key_expr=key_expression,
             units=units,
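Net effect of these hunks: RateLimit moves from a dataclass that validated and built its proto eagerly to a pydantic model whose model_validator rejects bad configurations at construction time, with proto conversion deferred to to_proto(). A small usage sketch (not part of the diff; it assumes the import path mirrors the module path):

from hatchet_sdk.rate_limit import RateLimit, RateLimitDuration

# Static key: every run consumes one unit of the shared "external-api" limit.
static = RateLimit(static_key="external-api", units=1, duration=RateLimitDuration.MINUTE)

# Dynamic key: the key is a CEL expression evaluated per run; a limit is required.
dynamic = RateLimit(dynamic_key="input.user_id", units=1, limit=10)
proto = dynamic.to_proto()  # CreateTaskRateLimit with key/key_expr/units populated

# Setting both keys now fails up front: this raises a pydantic ValidationError
# ("Cannot have both static key and dynamic key set").
# RateLimit(static_key="a", dynamic_key="input.user_id", limit=5)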
hatchet_sdk/runnables/contextvars.py
ADDED

@@ -0,0 +1,12 @@
+import asyncio
+from collections import Counter
+from contextvars import ContextVar
+
+ctx_workflow_run_id: ContextVar[str | None] = ContextVar(
+    "ctx_workflow_run_id", default=None
+)
+ctx_step_run_id: ContextVar[str | None] = ContextVar("ctx_step_run_id", default=None)
+ctx_worker_id: ContextVar[str | None] = ContextVar("ctx_worker_id", default=None)
+
+workflow_spawn_indices = Counter[str]()
+spawn_index_lock = asyncio.Lock()
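These module-level ContextVars let the worker stamp the current workflow run, step run, and worker IDs onto whichever asyncio task is executing, without threading the IDs through every call; workflow_spawn_indices plus spawn_index_lock serialize per-parent child-spawn numbering. A sketch of the intended access pattern (the set/reset discipline shown here is illustrative, not from the diff):

import asyncio

from hatchet_sdk.runnables.contextvars import (
    ctx_step_run_id,
    spawn_index_lock,
    workflow_spawn_indices,
)


async def handle_action(step_run_id: str) -> None:
    # Each asyncio task sees its own value, so concurrent runs don't clobber it.
    token = ctx_step_run_id.set(step_run_id)
    try:
        assert ctx_step_run_id.get() == step_run_id
    finally:
        ctx_step_run_id.reset(token)


async def next_spawn_index(workflow_run_id: str) -> int:
    # Serialize counter updates so concurrently spawned children get distinct indices.
    async with spawn_index_lock:
        index = workflow_spawn_indices[workflow_run_id]
        workflow_spawn_indices[workflow_run_id] += 1
        return index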
hatchet_sdk/runnables/standalone.py
ADDED

@@ -0,0 +1,194 @@
+import asyncio
+from datetime import datetime
+from typing import Any, Generic, cast, get_type_hints
+
+from google.protobuf import timestamp_pb2
+
+from hatchet_sdk.clients.admin import (
+    ScheduleTriggerWorkflowOptions,
+    TriggerWorkflowOptions,
+    WorkflowRunTriggerConfig,
+)
+from hatchet_sdk.clients.rest.models.cron_workflows import CronWorkflows
+from hatchet_sdk.contracts.workflows_pb2 import WorkflowVersion
+from hatchet_sdk.runnables.task import Task
+from hatchet_sdk.runnables.types import R, TWorkflowInput
+from hatchet_sdk.runnables.workflow import BaseWorkflow, Workflow
+from hatchet_sdk.utils.aio_utils import get_active_event_loop
+from hatchet_sdk.utils.typing import JSONSerializableMapping, is_basemodel_subclass
+from hatchet_sdk.workflow_run import WorkflowRunRef
+
+
+class TaskRunRef(Generic[TWorkflowInput, R]):
+    def __init__(
+        self,
+        standalone: "Standalone[TWorkflowInput, R]",
+        workflow_run_ref: WorkflowRunRef,
+    ):
+        self._s = standalone
+        self._wrr = workflow_run_ref
+
+    async def aio_result(self) -> R:
+        result = await self._wrr.workflow_listener.result(self._wrr.workflow_run_id)
+        return self._s._extract_result(result)
+
+    def result(self) -> R:
+        coro = self._wrr.workflow_listener.result(self._wrr.workflow_run_id)
+
+        loop = get_active_event_loop()
+
+        if loop is None:
+            loop = asyncio.new_event_loop()
+            asyncio.set_event_loop(loop)
+            try:
+                result = loop.run_until_complete(coro)
+            finally:
+                asyncio.set_event_loop(None)
+        else:
+            result = loop.run_until_complete(coro)
+
+        return self._s._extract_result(result)
+
+
+class Standalone(BaseWorkflow[TWorkflowInput], Generic[TWorkflowInput, R]):
+    def __init__(
+        self, workflow: Workflow[TWorkflowInput], task: Task[TWorkflowInput, R]
+    ) -> None:
+        super().__init__(config=workflow.config, client=workflow.client)
+
+        ## NOTE: This is a hack to assign the task back to the base workflow,
+        ## since the decorator to mutate the tasks is not being called.
+        self._default_tasks = [task]
+
+        self._workflow = workflow
+        self._task = task
+
+        return_type = get_type_hints(self._task.fn).get("return")
+
+        self._output_validator = (
+            return_type if is_basemodel_subclass(return_type) else None
+        )
+
+        self.config = self._workflow.config
+
+    def _extract_result(self, result: dict[str, Any]) -> R:
+        output = result.get(self._task.name)
+
+        if not self._output_validator:
+            return cast(R, output)
+
+        return cast(R, self._output_validator.model_validate(output))
+
+    def run(
+        self,
+        input: TWorkflowInput | None = None,
+        options: TriggerWorkflowOptions = TriggerWorkflowOptions(),
+    ) -> R:
+        return self._extract_result(self._workflow.run(input, options))
+
+    async def aio_run(
+        self,
+        input: TWorkflowInput | None = None,
+        options: TriggerWorkflowOptions = TriggerWorkflowOptions(),
+    ) -> R:
+        result = await self._workflow.aio_run(input, options)
+        return self._extract_result(result)
+
+    def run_no_wait(
+        self,
+        input: TWorkflowInput | None = None,
+        options: TriggerWorkflowOptions = TriggerWorkflowOptions(),
+    ) -> TaskRunRef[TWorkflowInput, R]:
+        ref = self._workflow.run_no_wait(input, options)
+
+        return TaskRunRef[TWorkflowInput, R](self, ref)
+
+    async def aio_run_no_wait(
+        self,
+        input: TWorkflowInput | None = None,
+        options: TriggerWorkflowOptions = TriggerWorkflowOptions(),
+    ) -> TaskRunRef[TWorkflowInput, R]:
+        ref = await self._workflow.aio_run_no_wait(input, options)
+
+        return TaskRunRef[TWorkflowInput, R](self, ref)
+
+    def run_many(self, workflows: list[WorkflowRunTriggerConfig]) -> list[R]:
+        return [
+            self._extract_result(result)
+            for result in self._workflow.run_many(workflows)
+        ]
+
+    async def aio_run_many(self, workflows: list[WorkflowRunTriggerConfig]) -> list[R]:
+        return [
+            self._extract_result(result)
+            for result in await self._workflow.aio_run_many(workflows)
+        ]
+
+    def run_many_no_wait(
+        self, workflows: list[WorkflowRunTriggerConfig]
+    ) -> list[TaskRunRef[TWorkflowInput, R]]:
+        refs = self._workflow.run_many_no_wait(workflows)
+
+        return [TaskRunRef[TWorkflowInput, R](self, ref) for ref in refs]
+
+    async def aio_run_many_no_wait(
+        self, workflows: list[WorkflowRunTriggerConfig]
+    ) -> list[TaskRunRef[TWorkflowInput, R]]:
+        refs = await self._workflow.aio_run_many_no_wait(workflows)
+
+        return [TaskRunRef[TWorkflowInput, R](self, ref) for ref in refs]
+
+    def schedule(
+        self,
+        schedules: list[datetime],
+        input: TWorkflowInput | None = None,
+        options: ScheduleTriggerWorkflowOptions = ScheduleTriggerWorkflowOptions(),
+    ) -> WorkflowVersion:
+        return self._workflow.schedule(
+            schedules=schedules,
+            input=input,
+            options=options,
+        )
+
+    async def aio_schedule(
+        self,
+        schedules: list[datetime | timestamp_pb2.Timestamp],
+        input: TWorkflowInput,
+        options: ScheduleTriggerWorkflowOptions = ScheduleTriggerWorkflowOptions(),
+    ) -> WorkflowVersion:
+        return await self._workflow.aio_schedule(
+            schedules=schedules,
+            input=input,
+            options=options,
+        )
+
+    def create_cron(
+        self,
+        cron_name: str,
+        expression: str,
+        input: TWorkflowInput,
+        additional_metadata: JSONSerializableMapping,
+    ) -> CronWorkflows:
+        return self._workflow.create_cron(
+            cron_name=cron_name,
+            expression=expression,
+            input=input,
+            additional_metadata=additional_metadata,
+        )
+
+    async def aio_create_cron(
+        self,
+        cron_name: str,
+        expression: str,
+        input: TWorkflowInput,
+        additional_metadata: JSONSerializableMapping,
+    ) -> CronWorkflows:
+        return await self._workflow.aio_create_cron(
+            cron_name=cron_name,
+            expression=expression,
+            input=input,
+            additional_metadata=additional_metadata,
+        )
+
+    def to_task(self) -> Task[TWorkflowInput, R]:
+        return self._task
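Standalone wraps a single-task workflow behind a task-shaped API: every run/schedule/cron method delegates to the inner Workflow, and _extract_result pulls this task's output out of the run's result dict, revalidating it against the function's return annotation when that annotation is a BaseModel subclass. A caller-side sketch (my_task and MyInput are stand-ins for a Standalone declared elsewhere; they are not names from the diff):

from pydantic import BaseModel


class MyInput(BaseModel):
    user_id: str


# Fire-and-forget, then block on the result. TaskRunRef.result() reuses the
# event loop returned by get_active_event_loop() and only creates (and later
# detaches) a temporary loop when none exists.
ref = my_task.run_no_wait(MyInput(user_id="abc"))
output = ref.result()

# In async code, await instead of blocking:
# ref = await my_task.aio_run_no_wait(MyInput(user_id="abc"))
# output = await ref.aio_result()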
hatchet_sdk/runnables/task.py
ADDED

@@ -0,0 +1,144 @@
+from typing import (
+    TYPE_CHECKING,
+    Any,
+    Awaitable,
+    Callable,
+    Generic,
+    TypeVar,
+    Union,
+    cast,
+)
+
+from hatchet_sdk.context.context import Context, DurableContext
+from hatchet_sdk.contracts.v1.workflows_pb2 import (
+    CreateTaskRateLimit,
+    DesiredWorkerLabels,
+)
+from hatchet_sdk.runnables.types import (
+    DEFAULT_EXECUTION_TIMEOUT,
+    DEFAULT_SCHEDULE_TIMEOUT,
+    ConcurrencyExpression,
+    R,
+    StepType,
+    TWorkflowInput,
+    is_async_fn,
+    is_durable_sync_fn,
+    is_sync_fn,
+)
+from hatchet_sdk.utils.timedelta_to_expression import Duration
+from hatchet_sdk.waits import Condition, OrGroup
+
+if TYPE_CHECKING:
+    from hatchet_sdk.runnables.workflow import Workflow
+
+
+T = TypeVar("T")
+
+
+def fall_back_to_default(value: T, default: T, fallback_value: T) -> T:
+    ## If the value is not the default, it's set
+    if value != default:
+        return value
+
+    ## Otherwise, it's unset, so return the fallback value
+    return fallback_value
+
+
+class Task(Generic[TWorkflowInput, R]):
+    def __init__(
+        self,
+        _fn: Union[
+            Callable[[TWorkflowInput, Context], R]
+            | Callable[[TWorkflowInput, Context], Awaitable[R]],
+            Callable[[TWorkflowInput, DurableContext], R]
+            | Callable[[TWorkflowInput, DurableContext], Awaitable[R]],
+        ],
+        is_durable: bool,
+        type: StepType,
+        workflow: "Workflow[TWorkflowInput]",
+        name: str,
+        execution_timeout: Duration = DEFAULT_EXECUTION_TIMEOUT,
+        schedule_timeout: Duration = DEFAULT_SCHEDULE_TIMEOUT,
+        parents: "list[Task[TWorkflowInput, Any]]" = [],
+        retries: int = 0,
+        rate_limits: list[CreateTaskRateLimit] = [],
+        desired_worker_labels: dict[str, DesiredWorkerLabels] = {},
+        backoff_factor: float | None = None,
+        backoff_max_seconds: int | None = None,
+        concurrency: list[ConcurrencyExpression] = [],
+        wait_for: list[Condition | OrGroup] = [],
+        skip_if: list[Condition | OrGroup] = [],
+        cancel_if: list[Condition | OrGroup] = [],
+    ) -> None:
+        self.is_durable = is_durable
+
+        self.fn = _fn
+        self.is_async_function = is_async_fn(self.fn)  # type: ignore
+
+        self.workflow = workflow
+
+        self.type = type
+        self.execution_timeout = fall_back_to_default(
+            execution_timeout, DEFAULT_EXECUTION_TIMEOUT, DEFAULT_EXECUTION_TIMEOUT
+        )
+        self.schedule_timeout = fall_back_to_default(
+            schedule_timeout, DEFAULT_SCHEDULE_TIMEOUT, DEFAULT_SCHEDULE_TIMEOUT
+        )
+        self.name = name
+        self.parents = parents
+        self.retries = retries
+        self.rate_limits = rate_limits
+        self.desired_worker_labels = desired_worker_labels
+        self.backoff_factor = backoff_factor
+        self.backoff_max_seconds = backoff_max_seconds
+        self.concurrency = concurrency
+
+        self.wait_for = self._flatten_conditions(wait_for)
+        self.skip_if = self._flatten_conditions(skip_if)
+        self.cancel_if = self._flatten_conditions(cancel_if)
+
+    def _flatten_conditions(
+        self, conditions: list[Condition | OrGroup]
+    ) -> list[Condition]:
+        flattened: list[Condition] = []
+
+        for condition in conditions:
+            if isinstance(condition, OrGroup):
+                for or_condition in condition.conditions:
+                    or_condition.base.or_group_id = condition.or_group_id
+
+                flattened.extend(condition.conditions)
+            else:
+                flattened.append(condition)
+
+        return flattened
+
+    def call(self, ctx: Context | DurableContext) -> R:
+        if self.is_async_function:
+            raise TypeError(f"{self.name} is not a sync function. Use `acall` instead.")
+
+        workflow_input = self.workflow._get_workflow_input(ctx)
+
+        if self.is_durable:
+            fn = cast(Callable[[TWorkflowInput, DurableContext], R], self.fn)
+            if is_durable_sync_fn(fn):
+                return fn(workflow_input, cast(DurableContext, ctx))
+        else:
+            fn = cast(Callable[[TWorkflowInput, Context], R], self.fn)
+            if is_sync_fn(fn):
+                return fn(workflow_input, cast(Context, ctx))
+
+        raise TypeError(f"{self.name} is not a sync function. Use `acall` instead.")
+
+    async def aio_call(self, ctx: Context | DurableContext) -> R:
+        if not self.is_async_function:
+            raise TypeError(
+                f"{self.name} is not an async function. Use `call` instead."
+            )
+
+        workflow_input = self.workflow._get_workflow_input(ctx)
+
+        if is_async_fn(self.fn):  # type: ignore
+            return await self.fn(workflow_input, cast(Context, ctx))  # type: ignore
+
+        raise TypeError(f"{self.name} is not an async function. Use `call` instead.")
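The interesting piece of Task construction is _flatten_conditions: members of an OrGroup are stamped with the group's shared or_group_id and then merged into one flat condition list, so the engine can later evaluate them as a single OR clause. A self-contained sketch of that behavior with simplified stand-in types (the uuid-based group id is an assumption; only the stamping-then-flattening logic mirrors the hunk):

import uuid
from dataclasses import dataclass, field


@dataclass
class Base:
    or_group_id: str | None = None


@dataclass
class Condition:
    name: str
    base: Base = field(default_factory=Base)


@dataclass
class OrGroup:
    conditions: list[Condition]
    or_group_id: str = field(default_factory=lambda: str(uuid.uuid4()))


def flatten(conditions: list[Condition | OrGroup]) -> list[Condition]:
    flattened: list[Condition] = []
    for condition in conditions:
        if isinstance(condition, OrGroup):
            # Stamp every member with the group's id, then merge into one list.
            for member in condition.conditions:
                member.base.or_group_id = condition.or_group_id
            flattened.extend(condition.conditions)
        else:
            flattened.append(condition)
    return flattened


a, b, c = Condition("a"), Condition("b"), Condition("c")
flat = flatten([a, OrGroup(conditions=[b, c])])
assert [x.name for x in flat] == ["a", "b", "c"]
assert b.base.or_group_id == c.base.or_group_id  # b and c form one OR clause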
hatchet_sdk/runnables/types.py
ADDED

@@ -0,0 +1,138 @@
+import asyncio
+from datetime import timedelta
+from enum import Enum
+from typing import Any, Awaitable, Callable, ParamSpec, Type, TypeGuard, TypeVar, Union
+
+from pydantic import BaseModel, ConfigDict, Field, StrictInt, model_validator
+
+from hatchet_sdk.context.context import Context, DurableContext
+from hatchet_sdk.utils.timedelta_to_expression import Duration
+from hatchet_sdk.utils.typing import JSONSerializableMapping
+
+ValidTaskReturnType = Union[BaseModel, JSONSerializableMapping, None]
+
+R = TypeVar("R", bound=Union[ValidTaskReturnType, Awaitable[ValidTaskReturnType]])
+P = ParamSpec("P")
+
+
+DEFAULT_EXECUTION_TIMEOUT = timedelta(seconds=60)
+DEFAULT_SCHEDULE_TIMEOUT = timedelta(minutes=5)
+DEFAULT_PRIORITY = 1
+
+
+class EmptyModel(BaseModel):
+    model_config = ConfigDict(extra="allow")
+
+
+class StickyStrategy(str, Enum):
+    SOFT = "SOFT"
+    HARD = "HARD"
+
+
+class ConcurrencyLimitStrategy(str, Enum):
+    CANCEL_IN_PROGRESS = "CANCEL_IN_PROGRESS"
+    DROP_NEWEST = "DROP_NEWEST"
+    QUEUE_NEWEST = "QUEUE_NEWEST"
+    GROUP_ROUND_ROBIN = "GROUP_ROUND_ROBIN"
+    CANCEL_NEWEST = "CANCEL_NEWEST"
+
+
+class ConcurrencyExpression(BaseModel):
+    """
+    Defines concurrency limits for a workflow using a CEL expression.
+    Args:
+        expression (str): CEL expression to determine concurrency grouping. (i.e. "input.user_id")
+        max_runs (int): Maximum number of concurrent workflow runs.
+        limit_strategy (ConcurrencyLimitStrategy): Strategy for handling limit violations.
+    Example:
+        ConcurrencyExpression("input.user_id", 5, ConcurrencyLimitStrategy.CANCEL_IN_PROGRESS)
+    """
+
+    expression: str
+    max_runs: int
+    limit_strategy: ConcurrencyLimitStrategy
+
+
+TWorkflowInput = TypeVar("TWorkflowInput", bound=BaseModel)
+
+
+class TaskDefaults(BaseModel):
+    schedule_timeout: Duration = DEFAULT_SCHEDULE_TIMEOUT
+    execution_timeout: Duration = DEFAULT_EXECUTION_TIMEOUT
+    priority: StrictInt = Field(gt=0, lt=4, default=DEFAULT_PRIORITY)
+
+
+class WorkflowConfig(BaseModel):
+    model_config = ConfigDict(extra="forbid", arbitrary_types_allowed=True)
+
+    name: str
+    description: str | None = None
+    version: str | None = None
+    on_events: list[str] = Field(default_factory=list)
+    on_crons: list[str] = Field(default_factory=list)
+    sticky: StickyStrategy | None = None
+    concurrency: ConcurrencyExpression | None = None
+    input_validator: Type[BaseModel] = EmptyModel
+
+    task_defaults: TaskDefaults = TaskDefaults()
+
+    @model_validator(mode="after")
+    def validate_concurrency_expression(self) -> "WorkflowConfig":
+        if not self.concurrency:
+            return self
+
+        expr = self.concurrency.expression
+
+        if not expr.startswith("input."):
+            return self
+
+        _, field = expr.split(".", maxsplit=2)
+
+        if field not in self.input_validator.model_fields.keys():
+            raise ValueError(
+                f"The concurrency expression provided relies on the `{field}` field, which was not present in `{self.input_validator.__name__}`."
+            )
+
+        return self
+
+
+class StepType(str, Enum):
+    DEFAULT = "default"
+    ON_FAILURE = "on_failure"
+    ON_SUCCESS = "on_success"
+
+
+AsyncFunc = Callable[[TWorkflowInput, Context], Awaitable[R]]
+SyncFunc = Callable[[TWorkflowInput, Context], R]
+TaskFunc = Union[AsyncFunc[TWorkflowInput, R], SyncFunc[TWorkflowInput, R]]
+
+
+def is_async_fn(
+    fn: TaskFunc[TWorkflowInput, R]
+) -> TypeGuard[AsyncFunc[TWorkflowInput, R]]:
+    return asyncio.iscoroutinefunction(fn)
+
+
+def is_sync_fn(
+    fn: TaskFunc[TWorkflowInput, R]
+) -> TypeGuard[SyncFunc[TWorkflowInput, R]]:
+    return not asyncio.iscoroutinefunction(fn)
+
+
+DurableAsyncFunc = Callable[[TWorkflowInput, DurableContext], Awaitable[R]]
+DurableSyncFunc = Callable[[TWorkflowInput, DurableContext], R]
+DurableTaskFunc = Union[
+    DurableAsyncFunc[TWorkflowInput, R], DurableSyncFunc[TWorkflowInput, R]
+]
+
+
+def is_durable_async_fn(
+    fn: Callable[..., Any]
+) -> TypeGuard[DurableAsyncFunc[TWorkflowInput, R]]:
+    return asyncio.iscoroutinefunction(fn)
+
+
+def is_durable_sync_fn(
+    fn: DurableTaskFunc[TWorkflowInput, R]
+) -> TypeGuard[DurableSyncFunc[TWorkflowInput, R]]:
+    return not asyncio.iscoroutinefunction(fn)
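WorkflowConfig's model_validator cross-checks an "input.<field>" concurrency expression against the declared input_validator model, so a mistyped field name fails when the workflow is defined rather than when runs start queuing. A usage sketch (not from the diff; it assumes the import path mirrors the module path):

from pydantic import BaseModel, ValidationError

from hatchet_sdk.runnables.types import (
    ConcurrencyExpression,
    ConcurrencyLimitStrategy,
    WorkflowConfig,
)


class MyInput(BaseModel):
    user_id: str


ok = WorkflowConfig(
    name="per-user",
    input_validator=MyInput,
    concurrency=ConcurrencyExpression(
        expression="input.user_id",
        max_runs=5,
        limit_strategy=ConcurrencyLimitStrategy.GROUP_ROUND_ROBIN,
    ),
)

try:
    WorkflowConfig(
        name="broken",
        input_validator=MyInput,
        concurrency=ConcurrencyExpression(
            expression="input.tenant_id",  # not a field of MyInput
            max_runs=5,
            limit_strategy=ConcurrencyLimitStrategy.GROUP_ROUND_ROBIN,
        ),
    )
except ValidationError:
    pass  # "...relies on the `tenant_id` field, which was not present in `MyInput`."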