hatchet-sdk 1.0.0__py3-none-any.whl → 1.0.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Files changed (73)
  1. hatchet_sdk/__init__.py +32 -16
  2. hatchet_sdk/client.py +25 -63
  3. hatchet_sdk/clients/admin.py +203 -142
  4. hatchet_sdk/clients/dispatcher/action_listener.py +42 -42
  5. hatchet_sdk/clients/dispatcher/dispatcher.py +18 -16
  6. hatchet_sdk/clients/durable_event_listener.py +327 -0
  7. hatchet_sdk/clients/rest/__init__.py +12 -1
  8. hatchet_sdk/clients/rest/api/log_api.py +258 -0
  9. hatchet_sdk/clients/rest/api/task_api.py +32 -6
  10. hatchet_sdk/clients/rest/api/workflow_runs_api.py +626 -0
  11. hatchet_sdk/clients/rest/models/__init__.py +12 -1
  12. hatchet_sdk/clients/rest/models/v1_log_line.py +94 -0
  13. hatchet_sdk/clients/rest/models/v1_log_line_level.py +39 -0
  14. hatchet_sdk/clients/rest/models/v1_log_line_list.py +110 -0
  15. hatchet_sdk/clients/rest/models/v1_task_summary.py +80 -64
  16. hatchet_sdk/clients/rest/models/v1_trigger_workflow_run_request.py +95 -0
  17. hatchet_sdk/clients/rest/models/v1_workflow_run_display_name.py +98 -0
  18. hatchet_sdk/clients/rest/models/v1_workflow_run_display_name_list.py +114 -0
  19. hatchet_sdk/clients/rest/models/workflow_run_shape_item_for_workflow_run_details.py +9 -4
  20. hatchet_sdk/clients/rest/models/workflow_runs_metrics.py +5 -1
  21. hatchet_sdk/clients/run_event_listener.py +0 -1
  22. hatchet_sdk/clients/v1/api_client.py +81 -0
  23. hatchet_sdk/context/context.py +86 -159
  24. hatchet_sdk/contracts/dispatcher_pb2_grpc.py +1 -1
  25. hatchet_sdk/contracts/events_pb2.py +2 -2
  26. hatchet_sdk/contracts/events_pb2_grpc.py +1 -1
  27. hatchet_sdk/contracts/v1/dispatcher_pb2.py +36 -0
  28. hatchet_sdk/contracts/v1/dispatcher_pb2.pyi +38 -0
  29. hatchet_sdk/contracts/v1/dispatcher_pb2_grpc.py +145 -0
  30. hatchet_sdk/contracts/v1/shared/condition_pb2.py +39 -0
  31. hatchet_sdk/contracts/v1/shared/condition_pb2.pyi +72 -0
  32. hatchet_sdk/contracts/v1/shared/condition_pb2_grpc.py +29 -0
  33. hatchet_sdk/contracts/v1/workflows_pb2.py +67 -0
  34. hatchet_sdk/contracts/v1/workflows_pb2.pyi +228 -0
  35. hatchet_sdk/contracts/v1/workflows_pb2_grpc.py +234 -0
  36. hatchet_sdk/contracts/workflows_pb2_grpc.py +1 -1
  37. hatchet_sdk/features/cron.py +91 -121
  38. hatchet_sdk/features/logs.py +16 -0
  39. hatchet_sdk/features/metrics.py +75 -0
  40. hatchet_sdk/features/rate_limits.py +45 -0
  41. hatchet_sdk/features/runs.py +221 -0
  42. hatchet_sdk/features/scheduled.py +114 -131
  43. hatchet_sdk/features/workers.py +41 -0
  44. hatchet_sdk/features/workflows.py +55 -0
  45. hatchet_sdk/hatchet.py +463 -165
  46. hatchet_sdk/opentelemetry/instrumentor.py +8 -13
  47. hatchet_sdk/rate_limit.py +33 -39
  48. hatchet_sdk/runnables/contextvars.py +12 -0
  49. hatchet_sdk/runnables/standalone.py +192 -0
  50. hatchet_sdk/runnables/task.py +144 -0
  51. hatchet_sdk/runnables/types.py +138 -0
  52. hatchet_sdk/runnables/workflow.py +771 -0
  53. hatchet_sdk/utils/aio_utils.py +0 -79
  54. hatchet_sdk/utils/proto_enums.py +0 -7
  55. hatchet_sdk/utils/timedelta_to_expression.py +23 -0
  56. hatchet_sdk/utils/typing.py +2 -2
  57. hatchet_sdk/v0/clients/rest_client.py +9 -0
  58. hatchet_sdk/v0/worker/action_listener_process.py +18 -2
  59. hatchet_sdk/waits.py +120 -0
  60. hatchet_sdk/worker/action_listener_process.py +64 -30
  61. hatchet_sdk/worker/runner/run_loop_manager.py +35 -26
  62. hatchet_sdk/worker/runner/runner.py +72 -55
  63. hatchet_sdk/worker/runner/utils/capture_logs.py +3 -11
  64. hatchet_sdk/worker/worker.py +155 -118
  65. hatchet_sdk/workflow_run.py +4 -5
  66. {hatchet_sdk-1.0.0.dist-info → hatchet_sdk-1.0.1.dist-info}/METADATA +1 -2
  67. {hatchet_sdk-1.0.0.dist-info → hatchet_sdk-1.0.1.dist-info}/RECORD +69 -43
  68. {hatchet_sdk-1.0.0.dist-info → hatchet_sdk-1.0.1.dist-info}/entry_points.txt +2 -0
  69. hatchet_sdk/clients/rest_client.py +0 -636
  70. hatchet_sdk/semver.py +0 -30
  71. hatchet_sdk/worker/runner/utils/error_with_traceback.py +0 -6
  72. hatchet_sdk/workflow.py +0 -527
  73. {hatchet_sdk-1.0.0.dist-info → hatchet_sdk-1.0.1.dist-info}/WHEEL +0 -0
hatchet_sdk/opentelemetry/instrumentor.py CHANGED
@@ -30,7 +30,7 @@ import hatchet_sdk
 from hatchet_sdk.clients.admin import (
     AdminClient,
     TriggerWorkflowOptions,
-    WorkflowRunDict,
+    WorkflowRunTriggerConfig,
 )
 from hatchet_sdk.clients.dispatcher.action_listener import Action
 from hatchet_sdk.clients.events import (
@@ -352,14 +352,12 @@ class HatchetInstrumentor(BaseInstrumentor): # type: ignore[misc]
     def _wrap_run_workflows(
         self,
         wrapped: Callable[
-            [list[WorkflowRunDict], TriggerWorkflowOptions | None], list[WorkflowRunRef]
+            [list[WorkflowRunTriggerConfig]],
+            list[WorkflowRunRef],
         ],
         instance: AdminClient,
-        args: tuple[
-            list[WorkflowRunDict],
-            TriggerWorkflowOptions | None,
-        ],
-        kwargs: dict[str, list[WorkflowRunDict] | TriggerWorkflowOptions | None],
+        args: tuple[list[WorkflowRunTriggerConfig],],
+        kwargs: dict[str, list[WorkflowRunTriggerConfig]],
     ) -> list[WorkflowRunRef]:
         with self._tracer.start_as_current_span(
             "hatchet.run_workflows",
@@ -370,15 +368,12 @@ class HatchetInstrumentor(BaseInstrumentor): # type: ignore[misc]
     async def _wrap_async_run_workflows(
         self,
         wrapped: Callable[
-            [list[WorkflowRunDict], TriggerWorkflowOptions | None],
+            [list[WorkflowRunTriggerConfig]],
             Coroutine[None, None, list[WorkflowRunRef]],
         ],
         instance: AdminClient,
-        args: tuple[
-            list[WorkflowRunDict],
-            TriggerWorkflowOptions | None,
-        ],
-        kwargs: dict[str, list[WorkflowRunDict] | TriggerWorkflowOptions | None],
+        args: tuple[list[WorkflowRunTriggerConfig],],
+        kwargs: dict[str, list[WorkflowRunTriggerConfig]],
     ) -> list[WorkflowRunRef]:
         with self._tracer.start_as_current_span(
             "hatchet.run_workflows",
hatchet_sdk/rate_limit.py CHANGED
@@ -1,9 +1,9 @@
-from dataclasses import dataclass
 from enum import Enum

 from celpy import CELEvalError, Environment # type: ignore
+from pydantic import BaseModel, model_validator

-from hatchet_sdk.contracts.workflows_pb2 import CreateStepRateLimit
+from hatchet_sdk.contracts.v1.workflows_pb2 import CreateTaskRateLimit


 def validate_cel_expression(expr: str) -> bool:
@@ -25,8 +25,7 @@ class RateLimitDuration(str, Enum):
     YEAR = "YEAR"


-@dataclass
-class RateLimit:
+class RateLimit(BaseModel):
     """
     Represents a rate limit configuration for a step in a workflow.

@@ -68,46 +67,41 @@ class RateLimit:
     limit: int | str | None = None
     duration: RateLimitDuration = RateLimitDuration.MINUTE

-    _req: CreateStepRateLimit | None = None
+    @model_validator(mode="after")
+    def validate_rate_limit(self) -> "RateLimit":
+        if self.dynamic_key and self.static_key:
+            raise ValueError("Cannot have both static key and dynamic key set")

-    def __post_init__(self) -> None:
-        # juggle the key and key_expr fields
+        if self.dynamic_key and not validate_cel_expression(self.dynamic_key):
+            raise ValueError(f"Invalid CEL expression: {self.dynamic_key}")
+
+        if not isinstance(self.units, int) and not validate_cel_expression(self.units):
+            raise ValueError(f"Invalid CEL expression: {self.units}")
+
+        if (
+            self.limit
+            and not isinstance(self.limit, int)
+            and not validate_cel_expression(self.limit)
+        ):
+            raise ValueError(f"Invalid CEL expression: {self.limit}")
+
+        if self.dynamic_key and not self.limit:
+            raise ValueError("CEL based keys requires limit to be set")
+
+        return self
+
+    def to_proto(self) -> CreateTaskRateLimit:
         key = self.static_key
         key_expression = self.dynamic_key

-        if key_expression is not None:
-            if key is not None:
-                raise ValueError("Cannot have both static key and dynamic key set")
-
-            key = key_expression
-            if not validate_cel_expression(key_expression):
-                raise ValueError(f"Invalid CEL expression: {key_expression}")
-
-        # juggle the units and units_expr fields
-        units = None
-        units_expression = None
-        if isinstance(self.units, int):
-            units = self.units
-        else:
-            if not validate_cel_expression(self.units):
-                raise ValueError(f"Invalid CEL expression: {self.units}")
-            units_expression = self.units
-
-        # juggle the limit and limit_expr fields
-        limit_expression = None
-
-        if self.limit:
-            if isinstance(self.limit, int):
-                limit_expression = f"{self.limit}"
-            else:
-                if not validate_cel_expression(self.limit):
-                    raise ValueError(f"Invalid CEL expression: {self.limit}")
-                limit_expression = self.limit
-
-        if key_expression is not None and limit_expression is None:
-            raise ValueError("CEL based keys requires limit to be set")
+        key = self.static_key or self.dynamic_key
+
+        units = self.units if isinstance(self.units, int) else None
+        units_expression = None if isinstance(self.units, int) else self.units
+
+        limit_expression = None if not self.limit else str(self.limit)

-        self._req = CreateStepRateLimit(
+        return CreateTaskRateLimit(
             key=key,
             key_expr=key_expression,
             units=units,
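
The net effect of this hunk: RateLimit becomes a pydantic model, the _req side effect is replaced by an explicit to_proto() method, and the key/units/limit checks move into a model_validator, so a bad configuration now fails at construction time rather than when the request proto is built. A self-contained sketch of that validation pattern (a hypothetical stand-in model, not the SDK class; the CEL checks are omitted):

from pydantic import BaseModel, model_validator


class MiniRateLimit(BaseModel):
    # Hypothetical stand-in mirroring RateLimit's key fields.
    static_key: str | None = None
    dynamic_key: str | None = None
    units: int | str = 1
    limit: int | str | None = None

    @model_validator(mode="after")
    def validate_rate_limit(self) -> "MiniRateLimit":
        if self.dynamic_key and self.static_key:
            raise ValueError("Cannot have both static key and dynamic key set")
        if self.dynamic_key and not self.limit:
            raise ValueError("CEL based keys requires limit to be set")
        return self


MiniRateLimit(static_key="external-api", units=1)  # ok
MiniRateLimit(dynamic_key="input.user_id", limit=10)  # ok
# MiniRateLimit(dynamic_key="input.user_id")  # would raise at construction time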
hatchet_sdk/runnables/contextvars.py ADDED
@@ -0,0 +1,12 @@
+import asyncio
+from collections import Counter
+from contextvars import ContextVar
+
+ctx_workflow_run_id: ContextVar[str | None] = ContextVar(
+    "ctx_workflow_run_id", default=None
+)
+ctx_step_run_id: ContextVar[str | None] = ContextVar("ctx_step_run_id", default=None)
+ctx_worker_id: ContextVar[str | None] = ContextVar("ctx_worker_id", default=None)
+
+workflow_spawn_indices = Counter[str]()
+spawn_index_lock = asyncio.Lock()
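
This new module centralizes per-run state. ContextVar gives each asyncio task its own view of the current workflow run, step run, and worker ids, while the module-level Counter plus asyncio.Lock track spawn indices across tasks. A small demo of the isolation property the ContextVars rely on:

import asyncio
from contextvars import ContextVar

ctx_run_id: ContextVar[str | None] = ContextVar("ctx_run_id", default=None)


async def handle(run_id: str) -> str | None:
    ctx_run_id.set(run_id)
    await asyncio.sleep(0)  # yield to the other task
    return ctx_run_id.get()  # still this task's own value


async def main() -> None:
    # Each task spawned by gather() runs in a copy of the context,
    # so the two set() calls never interfere with each other.
    assert await asyncio.gather(handle("a"), handle("b")) == ["a", "b"]


asyncio.run(main())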
hatchet_sdk/runnables/standalone.py ADDED
@@ -0,0 +1,192 @@
+import asyncio
+from datetime import datetime
+from typing import Any, Generic, cast, get_type_hints
+
+from hatchet_sdk.clients.admin import (
+    ScheduleTriggerWorkflowOptions,
+    TriggerWorkflowOptions,
+    WorkflowRunTriggerConfig,
+)
+from hatchet_sdk.clients.rest.models.cron_workflows import CronWorkflows
+from hatchet_sdk.contracts.workflows_pb2 import WorkflowVersion
+from hatchet_sdk.runnables.task import Task
+from hatchet_sdk.runnables.types import EmptyModel, R, TWorkflowInput
+from hatchet_sdk.runnables.workflow import BaseWorkflow, Workflow
+from hatchet_sdk.utils.aio_utils import get_active_event_loop
+from hatchet_sdk.utils.typing import JSONSerializableMapping, is_basemodel_subclass
+from hatchet_sdk.workflow_run import WorkflowRunRef
+
+
+class TaskRunRef(Generic[TWorkflowInput, R]):
+    def __init__(
+        self,
+        standalone: "Standalone[TWorkflowInput, R]",
+        workflow_run_ref: WorkflowRunRef,
+    ):
+        self._s = standalone
+        self._wrr = workflow_run_ref
+
+    async def aio_result(self) -> R:
+        result = await self._wrr.workflow_listener.result(self._wrr.workflow_run_id)
+        return self._s._extract_result(result)
+
+    def result(self) -> R:
+        coro = self._wrr.workflow_listener.result(self._wrr.workflow_run_id)
+
+        loop = get_active_event_loop()
+
+        if loop is None:
+            loop = asyncio.new_event_loop()
+            asyncio.set_event_loop(loop)
+            try:
+                result = loop.run_until_complete(coro)
+            finally:
+                asyncio.set_event_loop(None)
+        else:
+            result = loop.run_until_complete(coro)
+
+        return self._s._extract_result(result)
+
+
+class Standalone(BaseWorkflow[TWorkflowInput], Generic[TWorkflowInput, R]):
+    def __init__(
+        self, workflow: Workflow[TWorkflowInput], task: Task[TWorkflowInput, R]
+    ) -> None:
+        super().__init__(config=workflow.config, client=workflow.client)
+
+        ## NOTE: This is a hack to assign the task back to the base workflow,
+        ## since the decorator to mutate the tasks is not being called.
+        self._default_tasks = [task]
+
+        self._workflow = workflow
+        self._task = task
+
+        return_type = get_type_hints(self._task.fn).get("return")
+
+        self._output_validator = (
+            return_type if is_basemodel_subclass(return_type) else None
+        )
+
+        self.config = self._workflow.config
+
+    def _extract_result(self, result: dict[str, Any]) -> R:
+        output = result.get(self._task.name)
+
+        if not self._output_validator:
+            return cast(R, output)
+
+        return cast(R, self._output_validator.model_validate(output))
+
+    def run(
+        self,
+        input: TWorkflowInput = cast(TWorkflowInput, EmptyModel()),
+        options: TriggerWorkflowOptions = TriggerWorkflowOptions(),
+    ) -> R:
+        return self._extract_result(self._workflow.run(input, options))
+
+    async def aio_run(
+        self,
+        input: TWorkflowInput = cast(TWorkflowInput, EmptyModel()),
+        options: TriggerWorkflowOptions = TriggerWorkflowOptions(),
+    ) -> R:
+        result = await self._workflow.aio_run(input, options)
+        return self._extract_result(result)
+
+    def run_no_wait(
+        self,
+        input: TWorkflowInput = cast(TWorkflowInput, EmptyModel()),
+        options: TriggerWorkflowOptions = TriggerWorkflowOptions(),
+    ) -> TaskRunRef[TWorkflowInput, R]:
+        ref = self._workflow.run_no_wait(input, options)
+
+        return TaskRunRef[TWorkflowInput, R](self, ref)
+
+    async def aio_run_no_wait(
+        self,
+        input: TWorkflowInput = cast(TWorkflowInput, EmptyModel()),
+        options: TriggerWorkflowOptions = TriggerWorkflowOptions(),
+    ) -> TaskRunRef[TWorkflowInput, R]:
+        ref = await self._workflow.aio_run_no_wait(input, options)
+
+        return TaskRunRef[TWorkflowInput, R](self, ref)
+
+    def run_many(self, workflows: list[WorkflowRunTriggerConfig]) -> list[R]:
+        return [
+            self._extract_result(result)
+            for result in self._workflow.run_many(workflows)
+        ]
+
+    async def aio_run_many(self, workflows: list[WorkflowRunTriggerConfig]) -> list[R]:
+        return [
+            self._extract_result(result)
+            for result in await self._workflow.aio_run_many(workflows)
+        ]
+
+    def run_many_no_wait(
+        self, workflows: list[WorkflowRunTriggerConfig]
+    ) -> list[TaskRunRef[TWorkflowInput, R]]:
+        refs = self._workflow.run_many_no_wait(workflows)
+
+        return [TaskRunRef[TWorkflowInput, R](self, ref) for ref in refs]
+
+    async def aio_run_many_no_wait(
+        self, workflows: list[WorkflowRunTriggerConfig]
+    ) -> list[TaskRunRef[TWorkflowInput, R]]:
+        refs = await self._workflow.aio_run_many_no_wait(workflows)
+
+        return [TaskRunRef[TWorkflowInput, R](self, ref) for ref in refs]
+
+    def schedule(
+        self,
+        run_at: datetime,
+        input: TWorkflowInput | None = None,
+        options: ScheduleTriggerWorkflowOptions = ScheduleTriggerWorkflowOptions(),
+    ) -> WorkflowVersion:
+        return self._workflow.schedule(
+            run_at=run_at,
+            input=input,
+            options=options,
+        )
+
+    async def aio_schedule(
+        self,
+        run_at: datetime,
+        input: TWorkflowInput,
+        options: ScheduleTriggerWorkflowOptions = ScheduleTriggerWorkflowOptions(),
+    ) -> WorkflowVersion:
+        return await self._workflow.aio_schedule(
+            run_at=run_at,
+            input=input,
+            options=options,
+        )
+
+    def create_cron(
+        self,
+        cron_name: str,
+        expression: str,
+        input: TWorkflowInput,
+        additional_metadata: JSONSerializableMapping,
+    ) -> CronWorkflows:
+        return self._workflow.create_cron(
+            cron_name=cron_name,
+            expression=expression,
+            input=input,
+            additional_metadata=additional_metadata,
+        )
+
+    async def aio_create_cron(
+        self,
+        cron_name: str,
+        expression: str,
+        input: TWorkflowInput,
+        additional_metadata: JSONSerializableMapping,
+    ) -> CronWorkflows:
+        return await self._workflow.aio_create_cron(
+            cron_name=cron_name,
+            expression=expression,
+            input=input,
+            additional_metadata=additional_metadata,
+        )
+
+    def to_task(self) -> Task[TWorkflowInput, R]:
+        return self._task
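
Standalone wraps a one-task workflow so callers get the task's own return value instead of the workflow-level result dict: _extract_result pulls the entry keyed by the task name and, when the task function's return annotation is a pydantic model (discovered via get_type_hints), validates it back into that model. A standalone sketch of that unwrapping (hypothetical names, not the SDK API):

from typing import Any, Type

from pydantic import BaseModel


class Output(BaseModel):
    total: int


def extract_result(
    result: dict[str, Any], task_name: str, validator: Type[BaseModel] | None
) -> Any:
    # The engine reports workflow results as {task_name: output};
    # unwrap the single task's output and validate it if we can.
    output = result.get(task_name)
    return output if validator is None else validator.model_validate(output)


assert extract_result({"sum": {"total": 3}}, "sum", Output) == Output(total=3)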
hatchet_sdk/runnables/task.py ADDED
@@ -0,0 +1,144 @@
+from typing import (
+    TYPE_CHECKING,
+    Any,
+    Awaitable,
+    Callable,
+    Generic,
+    TypeVar,
+    Union,
+    cast,
+)
+
+from hatchet_sdk.context.context import Context, DurableContext
+from hatchet_sdk.contracts.v1.workflows_pb2 import (
+    CreateTaskRateLimit,
+    DesiredWorkerLabels,
+)
+from hatchet_sdk.runnables.types import (
+    DEFAULT_EXECUTION_TIMEOUT,
+    DEFAULT_SCHEDULE_TIMEOUT,
+    ConcurrencyExpression,
+    R,
+    StepType,
+    TWorkflowInput,
+    is_async_fn,
+    is_durable_sync_fn,
+    is_sync_fn,
+)
+from hatchet_sdk.utils.timedelta_to_expression import Duration
+from hatchet_sdk.waits import Condition, OrGroup
+
+if TYPE_CHECKING:
+    from hatchet_sdk.runnables.workflow import Workflow
+
+
+T = TypeVar("T")
+
+
+def fall_back_to_default(value: T, default: T, fallback_value: T) -> T:
+    ## If the value is not the default, it's set
+    if value != default:
+        return value
+
+    ## Otherwise, it's unset, so return the fallback value
+    return fallback_value
+
+
+class Task(Generic[TWorkflowInput, R]):
+    def __init__(
+        self,
+        _fn: Union[
+            Callable[[TWorkflowInput, Context], R]
+            | Callable[[TWorkflowInput, Context], Awaitable[R]],
+            Callable[[TWorkflowInput, DurableContext], R]
+            | Callable[[TWorkflowInput, DurableContext], Awaitable[R]],
+        ],
+        is_durable: bool,
+        type: StepType,
+        workflow: "Workflow[TWorkflowInput]",
+        name: str,
+        execution_timeout: Duration = DEFAULT_EXECUTION_TIMEOUT,
+        schedule_timeout: Duration = DEFAULT_SCHEDULE_TIMEOUT,
+        parents: "list[Task[TWorkflowInput, Any]]" = [],
+        retries: int = 0,
+        rate_limits: list[CreateTaskRateLimit] = [],
+        desired_worker_labels: dict[str, DesiredWorkerLabels] = {},
+        backoff_factor: float | None = None,
+        backoff_max_seconds: int | None = None,
+        concurrency: list[ConcurrencyExpression] = [],
+        wait_for: list[Condition | OrGroup] = [],
+        skip_if: list[Condition | OrGroup] = [],
+        cancel_if: list[Condition | OrGroup] = [],
+    ) -> None:
+        self.is_durable = is_durable
+
+        self.fn = _fn
+        self.is_async_function = is_async_fn(self.fn)  # type: ignore
+
+        self.workflow = workflow
+
+        self.type = type
+        self.execution_timeout = fall_back_to_default(
+            execution_timeout, DEFAULT_EXECUTION_TIMEOUT, DEFAULT_EXECUTION_TIMEOUT
+        )
+        self.schedule_timeout = fall_back_to_default(
+            schedule_timeout, DEFAULT_SCHEDULE_TIMEOUT, DEFAULT_SCHEDULE_TIMEOUT
+        )
+        self.name = name
+        self.parents = parents
+        self.retries = retries
+        self.rate_limits = rate_limits
+        self.desired_worker_labels = desired_worker_labels
+        self.backoff_factor = backoff_factor
+        self.backoff_max_seconds = backoff_max_seconds
+        self.concurrency = concurrency
+
+        self.wait_for = self._flatten_conditions(wait_for)
+        self.skip_if = self._flatten_conditions(skip_if)
+        self.cancel_if = self._flatten_conditions(cancel_if)
+
+    def _flatten_conditions(
+        self, conditions: list[Condition | OrGroup]
+    ) -> list[Condition]:
+        flattened: list[Condition] = []
+
+        for condition in conditions:
+            if isinstance(condition, OrGroup):
+                for or_condition in condition.conditions:
+                    or_condition.base.or_group_id = condition.or_group_id
+
+                flattened.extend(condition.conditions)
+            else:
+                flattened.append(condition)
+
+        return flattened
+
+    def call(self, ctx: Context | DurableContext) -> R:
+        if self.is_async_function:
+            raise TypeError(f"{self.name} is not a sync function. Use `acall` instead.")
+
+        workflow_input = self.workflow._get_workflow_input(ctx)
+
+        if self.is_durable:
+            fn = cast(Callable[[TWorkflowInput, DurableContext], R], self.fn)
+            if is_durable_sync_fn(fn):
+                return fn(workflow_input, cast(DurableContext, ctx))
+        else:
+            fn = cast(Callable[[TWorkflowInput, Context], R], self.fn)
+            if is_sync_fn(fn):
+                return fn(workflow_input, cast(Context, ctx))
+
+        raise TypeError(f"{self.name} is not a sync function. Use `acall` instead.")
+
+    async def aio_call(self, ctx: Context | DurableContext) -> R:
+        if not self.is_async_function:
+            raise TypeError(
+                f"{self.name} is not an async function. Use `call` instead."
+            )
+
+        workflow_input = self.workflow._get_workflow_input(ctx)
+
+        if is_async_fn(self.fn):  # type: ignore
+            return await self.fn(workflow_input, cast(Context, ctx))  # type: ignore
+
+        raise TypeError(f"{self.name} is not an async function. Use `call` instead.")
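
Two details worth noting in this file. First, call/aio_call enforce the sync/async split at runtime: a coroutine function registered as a task must go through aio_call, and vice versa. Second, _flatten_conditions turns an OrGroup into plain conditions by stamping every member with the group's id before flattening, so downstream code only ever sees a flat list[Condition]. A toy sketch of that flattening (simplified classes, not the SDK's Condition/OrGroup):

import uuid
from dataclasses import dataclass, field


@dataclass
class Cond:
    name: str
    or_group_id: str | None = None


@dataclass
class Group:
    conditions: list[Cond]
    or_group_id: str = field(default_factory=lambda: str(uuid.uuid4()))


def flatten(items: list[Cond | Group]) -> list[Cond]:
    out: list[Cond] = []
    for item in items:
        if isinstance(item, Group):
            for c in item.conditions:
                c.or_group_id = item.or_group_id  # members share one OR-group id
            out.extend(item.conditions)
        else:
            out.append(item)
    return out


flat = flatten([Cond("a"), Group([Cond("b"), Cond("c")])])
assert flat[1].or_group_id == flat[2].or_group_id  # b and c are OR'd together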
hatchet_sdk/runnables/types.py ADDED
@@ -0,0 +1,138 @@
+import asyncio
+from datetime import timedelta
+from enum import Enum
+from typing import Any, Awaitable, Callable, ParamSpec, Type, TypeGuard, TypeVar, Union
+
+from pydantic import BaseModel, ConfigDict, Field, StrictInt, model_validator
+
+from hatchet_sdk.context.context import Context, DurableContext
+from hatchet_sdk.utils.timedelta_to_expression import Duration
+from hatchet_sdk.utils.typing import JSONSerializableMapping
+
+ValidTaskReturnType = Union[BaseModel, JSONSerializableMapping, None]
+
+R = TypeVar("R", bound=Union[ValidTaskReturnType, Awaitable[ValidTaskReturnType]])
+P = ParamSpec("P")
+
+
+DEFAULT_EXECUTION_TIMEOUT = timedelta(seconds=60)
+DEFAULT_SCHEDULE_TIMEOUT = timedelta(minutes=5)
+DEFAULT_PRIORITY = 1
+
+
+class EmptyModel(BaseModel):
+    model_config = ConfigDict(extra="allow", frozen=True)
+
+
+class StickyStrategy(str, Enum):
+    SOFT = "SOFT"
+    HARD = "HARD"
+
+
+class ConcurrencyLimitStrategy(str, Enum):
+    CANCEL_IN_PROGRESS = "CANCEL_IN_PROGRESS"
+    DROP_NEWEST = "DROP_NEWEST"
+    QUEUE_NEWEST = "QUEUE_NEWEST"
+    GROUP_ROUND_ROBIN = "GROUP_ROUND_ROBIN"
+    CANCEL_NEWEST = "CANCEL_NEWEST"
+
+
+class ConcurrencyExpression(BaseModel):
+    """
+    Defines concurrency limits for a workflow using a CEL expression.
+    Args:
+        expression (str): CEL expression to determine concurrency grouping. (i.e. "input.user_id")
+        max_runs (int): Maximum number of concurrent workflow runs.
+        limit_strategy (ConcurrencyLimitStrategy): Strategy for handling limit violations.
+    Example:
+        ConcurrencyExpression("input.user_id", 5, ConcurrencyLimitStrategy.CANCEL_IN_PROGRESS)
+    """
+
+    expression: str
+    max_runs: int
+    limit_strategy: ConcurrencyLimitStrategy
+
+
+TWorkflowInput = TypeVar("TWorkflowInput", bound=BaseModel)
+
+
+class TaskDefaults(BaseModel):
+    schedule_timeout: Duration = DEFAULT_SCHEDULE_TIMEOUT
+    execution_timeout: Duration = DEFAULT_EXECUTION_TIMEOUT
+    priority: StrictInt = Field(gt=0, lt=4, default=DEFAULT_PRIORITY)
+
+
+class WorkflowConfig(BaseModel):
+    model_config = ConfigDict(extra="forbid", arbitrary_types_allowed=True)
+
+    name: str
+    description: str | None = None
+    version: str | None = None
+    on_events: list[str] = Field(default_factory=list)
+    on_crons: list[str] = Field(default_factory=list)
+    sticky: StickyStrategy | None = None
+    concurrency: ConcurrencyExpression | None = None
+    input_validator: Type[BaseModel] = EmptyModel
+
+    task_defaults: TaskDefaults = TaskDefaults()
+
+    @model_validator(mode="after")
+    def validate_concurrency_expression(self) -> "WorkflowConfig":
+        if not self.concurrency:
+            return self
+
+        expr = self.concurrency.expression
+
+        if not expr.startswith("input."):
+            return self
+
+        _, field = expr.split(".", maxsplit=2)
+
+        if field not in self.input_validator.model_fields.keys():
+            raise ValueError(
+                f"The concurrency expression provided relies on the `{field}` field, which was not present in `{self.input_validator.__name__}`."
+            )
+
+        return self
+
+
+class StepType(str, Enum):
+    DEFAULT = "default"
+    ON_FAILURE = "on_failure"
+    ON_SUCCESS = "on_success"
+
+
+AsyncFunc = Callable[[TWorkflowInput, Context], Awaitable[R]]
+SyncFunc = Callable[[TWorkflowInput, Context], R]
+TaskFunc = Union[AsyncFunc[TWorkflowInput, R], SyncFunc[TWorkflowInput, R]]
+
+
+def is_async_fn(
+    fn: TaskFunc[TWorkflowInput, R]
+) -> TypeGuard[AsyncFunc[TWorkflowInput, R]]:
+    return asyncio.iscoroutinefunction(fn)
+
+
+def is_sync_fn(
+    fn: TaskFunc[TWorkflowInput, R]
+) -> TypeGuard[SyncFunc[TWorkflowInput, R]]:
+    return not asyncio.iscoroutinefunction(fn)
+
+
+DurableAsyncFunc = Callable[[TWorkflowInput, DurableContext], Awaitable[R]]
+DurableSyncFunc = Callable[[TWorkflowInput, DurableContext], R]
+DurableTaskFunc = Union[
+    DurableAsyncFunc[TWorkflowInput, R], DurableSyncFunc[TWorkflowInput, R]
+]
+
+
+def is_durable_async_fn(
+    fn: Callable[..., Any]
+) -> TypeGuard[DurableAsyncFunc[TWorkflowInput, R]]:
+    return asyncio.iscoroutinefunction(fn)
+
+
+def is_durable_sync_fn(
+    fn: DurableTaskFunc[TWorkflowInput, R]
+) -> TypeGuard[DurableSyncFunc[TWorkflowInput, R]]:
+    return not asyncio.iscoroutinefunction(fn)
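
The is_async_fn/is_sync_fn pairs are thin TypeGuard wrappers over asyncio.iscoroutinefunction. Because a TypeGuard only narrows the positive branch, callers such as Task.call check each guard explicitly and raise if neither matches. A minimal sketch of that dispatch pattern:

import asyncio
from typing import Awaitable, Callable, TypeGuard, Union

SyncFn = Callable[[int], int]
AsyncFn = Callable[[int], Awaitable[int]]
Fn = Union[SyncFn, AsyncFn]


def is_async(fn: Fn) -> TypeGuard[AsyncFn]:
    return asyncio.iscoroutinefunction(fn)


def is_sync(fn: Fn) -> TypeGuard[SyncFn]:
    return not asyncio.iscoroutinefunction(fn)


async def dispatch(fn: Fn, x: int) -> int:
    # Same shape as Task.call/aio_call: guard, then call with a narrowed type.
    if is_async(fn):
        return await fn(x)
    if is_sync(fn):
        return fn(x)
    raise TypeError("unreachable")


assert asyncio.run(dispatch(lambda x: x + 1, 1)) == 2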