hatchet-sdk 0.42.0__py3-none-any.whl → 0.42.2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of hatchet-sdk might be problematic. Click here for more details.
- hatchet_sdk/clients/admin.py +10 -8
- hatchet_sdk/clients/dispatcher/action_listener.py +1 -1
- hatchet_sdk/clients/dispatcher/dispatcher.py +2 -2
- hatchet_sdk/clients/rest/tenacity_utils.py +6 -1
- hatchet_sdk/context/context.py +83 -46
- hatchet_sdk/context/worker_context.py +1 -1
- hatchet_sdk/hatchet.py +39 -36
- hatchet_sdk/utils/backoff.py +1 -1
- hatchet_sdk/utils/serialization.py +4 -1
- hatchet_sdk/utils/tracing.py +7 -4
- hatchet_sdk/utils/types.py +8 -0
- hatchet_sdk/utils/typing.py +9 -0
- hatchet_sdk/worker/action_listener_process.py +7 -9
- hatchet_sdk/worker/runner/run_loop_manager.py +15 -9
- hatchet_sdk/worker/runner/runner.py +57 -36
- hatchet_sdk/worker/worker.py +96 -59
- hatchet_sdk/workflow.py +80 -26
- {hatchet_sdk-0.42.0.dist-info → hatchet_sdk-0.42.2.dist-info}/METADATA +1 -1
- {hatchet_sdk-0.42.0.dist-info → hatchet_sdk-0.42.2.dist-info}/RECORD +21 -19
- {hatchet_sdk-0.42.0.dist-info → hatchet_sdk-0.42.2.dist-info}/entry_points.txt +2 -0
- {hatchet_sdk-0.42.0.dist-info → hatchet_sdk-0.42.2.dist-info}/WHEEL +0 -0
hatchet_sdk/clients/admin.py
CHANGED
|
@@ -54,16 +54,10 @@ class ChildTriggerWorkflowOptions(TypedDict):
|
|
|
54
54
|
sticky: bool | None = None
|
|
55
55
|
|
|
56
56
|
|
|
57
|
-
class WorkflowRunDict(TypedDict):
|
|
58
|
-
workflow_name: str
|
|
59
|
-
input: Any
|
|
60
|
-
options: Optional[dict]
|
|
61
|
-
|
|
62
|
-
|
|
63
57
|
class ChildWorkflowRunDict(TypedDict):
|
|
64
58
|
workflow_name: str
|
|
65
59
|
input: Any
|
|
66
|
-
options: ChildTriggerWorkflowOptions
|
|
60
|
+
options: ChildTriggerWorkflowOptions
|
|
67
61
|
key: str | None = None
|
|
68
62
|
|
|
69
63
|
|
|
@@ -73,6 +67,12 @@ class TriggerWorkflowOptions(ScheduleTriggerWorkflowOptions, TypedDict):
|
|
|
73
67
|
namespace: str | None = None
|
|
74
68
|
|
|
75
69
|
|
|
70
|
+
class WorkflowRunDict(TypedDict):
|
|
71
|
+
workflow_name: str
|
|
72
|
+
input: Any
|
|
73
|
+
options: TriggerWorkflowOptions | None
|
|
74
|
+
|
|
75
|
+
|
|
76
76
|
class DedupeViolationErr(Exception):
|
|
77
77
|
"""Raised by the Hatchet library to indicate that a workflow has already been run with this deduplication value."""
|
|
78
78
|
|
|
@@ -260,7 +260,9 @@ class AdminClientAioImpl(AdminClientBase):
|
|
|
260
260
|
|
|
261
261
|
@tenacity_retry
|
|
262
262
|
async def run_workflows(
|
|
263
|
-
self,
|
|
263
|
+
self,
|
|
264
|
+
workflows: list[WorkflowRunDict],
|
|
265
|
+
options: TriggerWorkflowOptions | None = None,
|
|
264
266
|
) -> List[WorkflowRunRef]:
|
|
265
267
|
if len(workflows) == 0:
|
|
266
268
|
raise ValueError("No workflows to run")
|
hatchet_sdk/clients/dispatcher/dispatcher.py
CHANGED
|
@@ -137,14 +137,14 @@ class DispatcherClient:
|
|
|
137
137
|
|
|
138
138
|
return response
|
|
139
139
|
|
|
140
|
-
def release_slot(self, step_run_id: str):
|
|
140
|
+
def release_slot(self, step_run_id: str) -> None:
|
|
141
141
|
self.client.ReleaseSlot(
|
|
142
142
|
ReleaseSlotRequest(stepRunId=step_run_id),
|
|
143
143
|
timeout=DEFAULT_REGISTER_TIMEOUT,
|
|
144
144
|
metadata=get_metadata(self.token),
|
|
145
145
|
)
|
|
146
146
|
|
|
147
|
-
def refresh_timeout(self, step_run_id: str, increment_by: str):
|
|
147
|
+
def refresh_timeout(self, step_run_id: str, increment_by: str) -> None:
|
|
148
148
|
self.client.RefreshTimeout(
|
|
149
149
|
RefreshTimeoutRequest(
|
|
150
150
|
stepRunId=step_run_id,
|
hatchet_sdk/clients/rest/tenacity_utils.py
CHANGED
|
@@ -1,10 +1,15 @@
|
|
|
1
|
+
from typing import Callable, ParamSpec, TypeVar
|
|
2
|
+
|
|
1
3
|
import grpc
|
|
2
4
|
import tenacity
|
|
3
5
|
|
|
4
6
|
from hatchet_sdk.logger import logger
|
|
5
7
|
|
|
8
|
+
P = ParamSpec("P")
|
|
9
|
+
R = TypeVar("R")
|
|
10
|
+
|
|
6
11
|
|
|
7
|
-
def tenacity_retry(func):
|
|
12
|
+
def tenacity_retry(func: Callable[P, R]) -> Callable[P, R]:
|
|
8
13
|
return tenacity.retry(
|
|
9
14
|
reraise=True,
|
|
10
15
|
wait=tenacity.wait_exponential_jitter(),
|
hatchet_sdk/context/context.py
CHANGED
|
@@ -2,7 +2,10 @@ import inspect
|
|
|
2
2
|
import json
|
|
3
3
|
import traceback
|
|
4
4
|
from concurrent.futures import Future, ThreadPoolExecutor
|
|
5
|
-
from typing import
|
|
5
|
+
from typing import Any, Generic, Type, TypeVar, cast, overload
|
|
6
|
+
from warnings import warn
|
|
7
|
+
|
|
8
|
+
from pydantic import BaseModel, StrictStr
|
|
6
9
|
|
|
7
10
|
from hatchet_sdk.clients.events import EventClient
|
|
8
11
|
from hatchet_sdk.clients.rest.tenacity_utils import tenacity_retry
|
|
@@ -10,11 +13,13 @@ from hatchet_sdk.clients.rest_client import RestApi
|
|
|
10
13
|
from hatchet_sdk.clients.run_event_listener import RunEventListenerClient
|
|
11
14
|
from hatchet_sdk.clients.workflow_listener import PooledWorkflowRunListener
|
|
12
15
|
from hatchet_sdk.context.worker_context import WorkerContext
|
|
13
|
-
from hatchet_sdk.contracts.dispatcher_pb2 import OverridesData
|
|
14
|
-
from hatchet_sdk.contracts.workflows_pb2 import (
|
|
16
|
+
from hatchet_sdk.contracts.dispatcher_pb2 import OverridesData # type: ignore
|
|
17
|
+
from hatchet_sdk.contracts.workflows_pb2 import ( # type: ignore[attr-defined]
|
|
15
18
|
BulkTriggerWorkflowRequest,
|
|
16
19
|
TriggerWorkflowRequest,
|
|
17
20
|
)
|
|
21
|
+
from hatchet_sdk.utils.types import WorkflowValidator
|
|
22
|
+
from hatchet_sdk.utils.typing import is_basemodel_subclass
|
|
18
23
|
from hatchet_sdk.workflow_run import WorkflowRunRef
|
|
19
24
|
|
|
20
25
|
from ..clients.admin import (
|
|
@@ -24,25 +29,34 @@ from ..clients.admin import (
|
|
|
24
29
|
TriggerWorkflowOptions,
|
|
25
30
|
WorkflowRunDict,
|
|
26
31
|
)
|
|
27
|
-
from ..clients.dispatcher.dispatcher import
|
|
32
|
+
from ..clients.dispatcher.dispatcher import ( # type: ignore[attr-defined]
|
|
33
|
+
Action,
|
|
34
|
+
DispatcherClient,
|
|
35
|
+
)
|
|
28
36
|
from ..logger import logger
|
|
29
37
|
|
|
30
38
|
DEFAULT_WORKFLOW_POLLING_INTERVAL = 5 # Seconds
|
|
31
39
|
|
|
40
|
+
T = TypeVar("T", bound=BaseModel)
|
|
32
41
|
|
|
33
|
-
|
|
42
|
+
|
|
43
|
+
def get_caller_file_path() -> str:
|
|
34
44
|
caller_frame = inspect.stack()[2]
|
|
35
45
|
|
|
36
46
|
return caller_frame.filename
|
|
37
47
|
|
|
38
48
|
|
|
39
49
|
class BaseContext:
|
|
50
|
+
|
|
51
|
+
action: Action
|
|
52
|
+
spawn_index: int
|
|
53
|
+
|
|
40
54
|
def _prepare_workflow_options(
|
|
41
55
|
self,
|
|
42
|
-
key: str = None,
|
|
56
|
+
key: str | None = None,
|
|
43
57
|
options: ChildTriggerWorkflowOptions | None = None,
|
|
44
|
-
worker_id: str = None,
|
|
45
|
-
):
|
|
58
|
+
worker_id: str | None = None,
|
|
59
|
+
) -> TriggerWorkflowOptions:
|
|
46
60
|
workflow_run_id = self.action.workflow_run_id
|
|
47
61
|
step_run_id = self.action.step_run_id
|
|
48
62
|
|
|
@@ -54,7 +68,8 @@ class BaseContext:
|
|
|
54
68
|
if options is not None and "additional_metadata" in options:
|
|
55
69
|
meta = options["additional_metadata"]
|
|
56
70
|
|
|
57
|
-
|
|
71
|
+
## TODO: Pydantic here to simplify this
|
|
72
|
+
trigger_options: TriggerWorkflowOptions = { # type: ignore[typeddict-item]
|
|
58
73
|
"parent_id": workflow_run_id,
|
|
59
74
|
"parent_step_run_id": step_run_id,
|
|
60
75
|
"child_key": key,
|
|
@@ -95,9 +110,9 @@ class ContextAioImpl(BaseContext):
|
|
|
95
110
|
async def spawn_workflow(
|
|
96
111
|
self,
|
|
97
112
|
workflow_name: str,
|
|
98
|
-
input: dict = {},
|
|
99
|
-
key: str = None,
|
|
100
|
-
options: ChildTriggerWorkflowOptions = None,
|
|
113
|
+
input: dict[str, Any] = {},
|
|
114
|
+
key: str | None = None,
|
|
115
|
+
options: ChildTriggerWorkflowOptions | None = None,
|
|
101
116
|
) -> WorkflowRunRef:
|
|
102
117
|
worker_id = self.worker.id()
|
|
103
118
|
# if (
|
|
@@ -118,15 +133,15 @@ class ContextAioImpl(BaseContext):
|
|
|
118
133
|
|
|
119
134
|
@tenacity_retry
|
|
120
135
|
async def spawn_workflows(
|
|
121
|
-
self, child_workflow_runs:
|
|
122
|
-
) ->
|
|
136
|
+
self, child_workflow_runs: list[ChildWorkflowRunDict]
|
|
137
|
+
) -> list[WorkflowRunRef]:
|
|
123
138
|
|
|
124
139
|
if len(child_workflow_runs) == 0:
|
|
125
140
|
raise Exception("no child workflows to spawn")
|
|
126
141
|
|
|
127
142
|
worker_id = self.worker.id()
|
|
128
143
|
|
|
129
|
-
bulk_trigger_workflow_runs: WorkflowRunDict = []
|
|
144
|
+
bulk_trigger_workflow_runs: list[WorkflowRunDict] = []
|
|
130
145
|
for child_workflow_run in child_workflow_runs:
|
|
131
146
|
workflow_name = child_workflow_run["workflow_name"]
|
|
132
147
|
input = child_workflow_run["input"]
|
|
@@ -134,7 +149,8 @@ class ContextAioImpl(BaseContext):
|
|
|
134
149
|
key = child_workflow_run.get("key")
|
|
135
150
|
options = child_workflow_run.get("options", {})
|
|
136
151
|
|
|
137
|
-
|
|
152
|
+
## TODO: figure out why this is failing
|
|
153
|
+
trigger_options = self._prepare_workflow_options(key, options, worker_id) # type: ignore[arg-type]
|
|
138
154
|
|
|
139
155
|
bulk_trigger_workflow_runs.append(
|
|
140
156
|
WorkflowRunDict(
|
|
@@ -161,8 +177,10 @@ class Context(BaseContext):
|
|
|
161
177
|
workflow_run_event_listener: RunEventListenerClient,
|
|
162
178
|
worker: WorkerContext,
|
|
163
179
|
namespace: str = "",
|
|
180
|
+
validator_registry: dict[str, WorkflowValidator] = {},
|
|
164
181
|
):
|
|
165
182
|
self.worker = worker
|
|
183
|
+
self.validator_registry = validator_registry
|
|
166
184
|
|
|
167
185
|
self.aio = ContextAioImpl(
|
|
168
186
|
action,
|
|
@@ -179,11 +197,11 @@ class Context(BaseContext):
|
|
|
179
197
|
# Check the type of action.action_payload before attempting to load it as JSON
|
|
180
198
|
if isinstance(action.action_payload, (str, bytes, bytearray)):
|
|
181
199
|
try:
|
|
182
|
-
self.data = json.loads(action.action_payload)
|
|
200
|
+
self.data = cast(dict[str, Any], json.loads(action.action_payload))
|
|
183
201
|
except Exception as e:
|
|
184
202
|
logger.error(f"Error parsing action payload: {e}")
|
|
185
203
|
# Assign an empty dictionary if parsing fails
|
|
186
|
-
self.data = {}
|
|
204
|
+
self.data: dict[str, Any] = {} # type: ignore[no-redef]
|
|
187
205
|
else:
|
|
188
206
|
# Directly assign the payload to self.data if it's already a dict
|
|
189
207
|
self.data = (
|
|
@@ -191,6 +209,7 @@ class Context(BaseContext):
|
|
|
191
209
|
)
|
|
192
210
|
|
|
193
211
|
self.action = action
|
|
212
|
+
|
|
194
213
|
# FIXME: stepRunId is a legacy field, we should remove it
|
|
195
214
|
self.stepRunId = action.step_run_id
|
|
196
215
|
|
|
@@ -218,33 +237,53 @@ class Context(BaseContext):
|
|
|
218
237
|
else:
|
|
219
238
|
self.input = self.data.get("input", {})
|
|
220
239
|
|
|
221
|
-
def step_output(self, step: str):
|
|
240
|
+
def step_output(self, step: str) -> dict[str, Any] | BaseModel:
|
|
241
|
+
validators = self.validator_registry.get(step)
|
|
242
|
+
|
|
222
243
|
try:
|
|
223
|
-
|
|
244
|
+
parent_step_data = cast(dict[str, Any], self.data["parents"][step])
|
|
224
245
|
except KeyError:
|
|
225
246
|
raise ValueError(f"Step output for '{step}' not found")
|
|
226
247
|
|
|
248
|
+
if validators and (v := validators.step_output):
|
|
249
|
+
return v.model_validate(parent_step_data)
|
|
250
|
+
|
|
251
|
+
return parent_step_data
|
|
252
|
+
|
|
227
253
|
def triggered_by_event(self) -> bool:
|
|
228
|
-
return self.data.get("triggered_by", "") == "event"
|
|
254
|
+
return cast(str, self.data.get("triggered_by", "")) == "event"
|
|
255
|
+
|
|
256
|
+
def workflow_input(self) -> dict[str, Any] | T:
|
|
257
|
+
if (r := self.validator_registry.get(self.action.action_id)) and (
|
|
258
|
+
i := r.workflow_input
|
|
259
|
+
):
|
|
260
|
+
return cast(
|
|
261
|
+
T,
|
|
262
|
+
i.model_validate(self.input),
|
|
263
|
+
)
|
|
229
264
|
|
|
230
|
-
def workflow_input(self):
|
|
231
265
|
return self.input
|
|
232
266
|
|
|
233
|
-
def workflow_run_id(self):
|
|
267
|
+
def workflow_run_id(self) -> str:
|
|
234
268
|
return self.action.workflow_run_id
|
|
235
269
|
|
|
236
|
-
def cancel(self):
|
|
270
|
+
def cancel(self) -> None:
|
|
237
271
|
logger.debug("cancelling step...")
|
|
238
272
|
self.exit_flag = True
|
|
239
273
|
|
|
240
274
|
# done returns true if the context has been cancelled
|
|
241
|
-
def done(self):
|
|
275
|
+
def done(self) -> bool:
|
|
242
276
|
return self.exit_flag
|
|
243
277
|
|
|
244
|
-
def playground(self, name: str, default: str = None):
|
|
278
|
+
def playground(self, name: str, default: str | None = None) -> str | None:
|
|
245
279
|
# if the key exists in the overrides_data field, return the value
|
|
246
280
|
if name in self.overrides_data:
|
|
247
|
-
|
|
281
|
+
warn(
|
|
282
|
+
"Use of `overrides_data` is deprecated.",
|
|
283
|
+
DeprecationWarning,
|
|
284
|
+
stacklevel=1,
|
|
285
|
+
)
|
|
286
|
+
return str(self.overrides_data[name])
|
|
248
287
|
|
|
249
288
|
caller_file = get_caller_file_path()
|
|
250
289
|
|
|
@@ -259,7 +298,7 @@ class Context(BaseContext):
|
|
|
259
298
|
|
|
260
299
|
return default
|
|
261
300
|
|
|
262
|
-
def _log(self, line: str) ->
|
|
301
|
+
def _log(self, line: str) -> tuple[bool, Exception | None]:
|
|
263
302
|
try:
|
|
264
303
|
self.event_client.log(message=line, step_run_id=self.stepRunId)
|
|
265
304
|
return True, None
|
|
@@ -267,7 +306,7 @@ class Context(BaseContext):
|
|
|
267
306
|
# we don't want to raise an exception here, as it will kill the log thread
|
|
268
307
|
return False, e
|
|
269
308
|
|
|
270
|
-
def log(self, line, raise_on_error: bool = False):
|
|
309
|
+
def log(self, line: Any, raise_on_error: bool = False) -> None:
|
|
271
310
|
if self.stepRunId == "":
|
|
272
311
|
return
|
|
273
312
|
|
|
@@ -277,9 +316,9 @@ class Context(BaseContext):
|
|
|
277
316
|
except Exception:
|
|
278
317
|
line = str(line)
|
|
279
318
|
|
|
280
|
-
future
|
|
319
|
+
future = self.logger_thread_pool.submit(self._log, line)
|
|
281
320
|
|
|
282
|
-
def handle_result(future: Future):
|
|
321
|
+
def handle_result(future: Future[tuple[bool, Exception | None]]) -> None:
|
|
283
322
|
success, exception = future.result()
|
|
284
323
|
if not success and exception:
|
|
285
324
|
if raise_on_error:
|
|
@@ -297,22 +336,22 @@ class Context(BaseContext):
|
|
|
297
336
|
|
|
298
337
|
future.add_done_callback(handle_result)
|
|
299
338
|
|
|
300
|
-
def release_slot(self):
|
|
339
|
+
def release_slot(self) -> None:
|
|
301
340
|
return self.dispatcher_client.release_slot(self.stepRunId)
|
|
302
341
|
|
|
303
|
-
def _put_stream(self, data: str | bytes):
|
|
342
|
+
def _put_stream(self, data: str | bytes) -> None:
|
|
304
343
|
try:
|
|
305
344
|
self.event_client.stream(data=data, step_run_id=self.stepRunId)
|
|
306
345
|
except Exception as e:
|
|
307
346
|
logger.error(f"Error putting stream event: {e}")
|
|
308
347
|
|
|
309
|
-
def put_stream(self, data: str | bytes):
|
|
348
|
+
def put_stream(self, data: str | bytes) -> None:
|
|
310
349
|
if self.stepRunId == "":
|
|
311
350
|
return
|
|
312
351
|
|
|
313
352
|
self.stream_event_thread_pool.submit(self._put_stream, data)
|
|
314
353
|
|
|
315
|
-
def refresh_timeout(self, increment_by: str):
|
|
354
|
+
def refresh_timeout(self, increment_by: str) -> None:
|
|
316
355
|
try:
|
|
317
356
|
return self.dispatcher_client.refresh_timeout(
|
|
318
357
|
step_run_id=self.stepRunId, increment_by=increment_by
|
|
@@ -320,28 +359,28 @@ class Context(BaseContext):
|
|
|
320
359
|
except Exception as e:
|
|
321
360
|
logger.error(f"Error refreshing timeout: {e}")
|
|
322
361
|
|
|
323
|
-
def retry_count(self):
|
|
362
|
+
def retry_count(self) -> int:
|
|
324
363
|
return self.action.retry_count
|
|
325
364
|
|
|
326
|
-
def additional_metadata(self):
|
|
365
|
+
def additional_metadata(self) -> dict[str, Any] | None:
|
|
327
366
|
return self.action.additional_metadata
|
|
328
367
|
|
|
329
|
-
def child_index(self):
|
|
368
|
+
def child_index(self) -> int | None:
|
|
330
369
|
return self.action.child_workflow_index
|
|
331
370
|
|
|
332
|
-
def child_key(self):
|
|
371
|
+
def child_key(self) -> str | None:
|
|
333
372
|
return self.action.child_workflow_key
|
|
334
373
|
|
|
335
|
-
def parent_workflow_run_id(self):
|
|
374
|
+
def parent_workflow_run_id(self) -> str | None:
|
|
336
375
|
return self.action.parent_workflow_run_id
|
|
337
376
|
|
|
338
|
-
def fetch_run_failures(self):
|
|
377
|
+
def fetch_run_failures(self) -> list[dict[str, StrictStr]]:
|
|
339
378
|
data = self.rest_client.workflow_run_get(self.action.workflow_run_id)
|
|
340
379
|
other_job_runs = [
|
|
341
|
-
run for run in data.job_runs if run.job_id != self.action.job_id
|
|
380
|
+
run for run in (data.job_runs or []) if run.job_id != self.action.job_id
|
|
342
381
|
]
|
|
343
382
|
# TODO: Parse Step Runs using a Pydantic Model rather than a hand crafted dictionary
|
|
344
|
-
|
|
383
|
+
return [
|
|
345
384
|
{
|
|
346
385
|
"step_id": step_run.step_id,
|
|
347
386
|
"step_run_action_name": step_run.step.action,
|
|
@@ -350,7 +389,5 @@ class Context(BaseContext):
|
|
|
350
389
|
for job_run in other_job_runs
|
|
351
390
|
if job_run.step_runs
|
|
352
391
|
for step_run in job_run.step_runs
|
|
353
|
-
if step_run.error
|
|
392
|
+
if step_run.error and step_run.step
|
|
354
393
|
]
|
|
355
|
-
|
|
356
|
-
return failed_step_runs
|
hatchet_sdk/hatchet.py
CHANGED
|
@@ -1,7 +1,8 @@
|
|
|
1
1
|
import asyncio
|
|
2
2
|
import logging
|
|
3
|
-
from typing import Any, Callable, Optional,
|
|
3
|
+
from typing import Any, Callable, Optional, Type, TypeVar, cast, get_type_hints
|
|
4
4
|
|
|
5
|
+
from pydantic import BaseModel
|
|
5
6
|
from typing_extensions import deprecated
|
|
6
7
|
|
|
7
8
|
from hatchet_sdk.clients.rest_client import RestApi
|
|
@@ -27,14 +28,17 @@ from .clients.events import EventClient
|
|
|
27
28
|
from .clients.run_event_listener import RunEventListenerClient
|
|
28
29
|
from .logger import logger
|
|
29
30
|
from .worker.worker import Worker
|
|
30
|
-
from .workflow import
|
|
31
|
+
from .workflow import (
|
|
32
|
+
ConcurrencyExpression,
|
|
33
|
+
WorkflowInterface,
|
|
34
|
+
WorkflowMeta,
|
|
35
|
+
WorkflowStepProtocol,
|
|
36
|
+
)
|
|
31
37
|
|
|
32
|
-
|
|
33
|
-
R = TypeVar("R")
|
|
38
|
+
T = TypeVar("T", bound=BaseModel)
|
|
34
39
|
|
|
35
40
|
|
|
36
|
-
|
|
37
|
-
def workflow( # type: ignore[no-untyped-def]
|
|
41
|
+
def workflow(
|
|
38
42
|
name: str = "",
|
|
39
43
|
on_events: list[str] | None = None,
|
|
40
44
|
on_crons: list[str] | None = None,
|
|
@@ -44,11 +48,12 @@ def workflow( # type: ignore[no-untyped-def]
|
|
|
44
48
|
sticky: StickyStrategy = None,
|
|
45
49
|
default_priority: int | None = None,
|
|
46
50
|
concurrency: ConcurrencyExpression | None = None,
|
|
47
|
-
|
|
51
|
+
input_validator: Type[T] | None = None,
|
|
52
|
+
) -> Callable[[Type[WorkflowInterface]], WorkflowMeta]:
|
|
48
53
|
on_events = on_events or []
|
|
49
54
|
on_crons = on_crons or []
|
|
50
55
|
|
|
51
|
-
def inner(cls:
|
|
56
|
+
def inner(cls: Type[WorkflowInterface]) -> WorkflowMeta:
|
|
52
57
|
cls.on_events = on_events
|
|
53
58
|
cls.on_crons = on_crons
|
|
54
59
|
cls.name = name or str(cls.__name__)
|
|
@@ -62,7 +67,8 @@ def workflow( # type: ignore[no-untyped-def]
|
|
|
62
67
|
# with WorkflowMeta as its metaclass
|
|
63
68
|
|
|
64
69
|
## TODO: Figure out how to type this metaclass correctly
|
|
65
|
-
|
|
70
|
+
cls.input_validator = input_validator
|
|
71
|
+
return WorkflowMeta(cls.name, cls.__bases__, dict(cls.__dict__))
|
|
66
72
|
|
|
67
73
|
return inner
|
|
68
74
|
|
|
@@ -76,10 +82,10 @@ def step(
|
|
|
76
82
|
desired_worker_labels: dict[str, DesiredWorkerLabel] = {},
|
|
77
83
|
backoff_factor: float | None = None,
|
|
78
84
|
backoff_max_seconds: int | None = None,
|
|
79
|
-
) -> Callable[[
|
|
85
|
+
) -> Callable[[WorkflowStepProtocol], WorkflowStepProtocol]:
|
|
80
86
|
parents = parents or []
|
|
81
87
|
|
|
82
|
-
def inner(func:
|
|
88
|
+
def inner(func: WorkflowStepProtocol) -> WorkflowStepProtocol:
|
|
83
89
|
limits = None
|
|
84
90
|
if rate_limits:
|
|
85
91
|
limits = [
|
|
@@ -87,20 +93,19 @@ def step(
|
|
|
87
93
|
for rate_limit in rate_limits or []
|
|
88
94
|
]
|
|
89
95
|
|
|
90
|
-
|
|
91
|
-
func.
|
|
92
|
-
func.
|
|
93
|
-
func.
|
|
94
|
-
func.
|
|
95
|
-
func.
|
|
96
|
-
func.
|
|
97
|
-
func._step_backoff_max_seconds = backoff_max_seconds # type: ignore[attr-defined]
|
|
96
|
+
func._step_name = name.lower() or str(func.__name__).lower()
|
|
97
|
+
func._step_parents = parents
|
|
98
|
+
func._step_timeout = timeout
|
|
99
|
+
func._step_retries = retries
|
|
100
|
+
func._step_rate_limits = limits
|
|
101
|
+
func._step_backoff_factor = backoff_factor
|
|
102
|
+
func._step_backoff_max_seconds = backoff_max_seconds
|
|
98
103
|
|
|
99
|
-
func._step_desired_worker_labels = {}
|
|
104
|
+
func._step_desired_worker_labels = {}
|
|
100
105
|
|
|
101
106
|
for key, d in desired_worker_labels.items():
|
|
102
107
|
value = d["value"] if "value" in d else None
|
|
103
|
-
func._step_desired_worker_labels[key] = DesiredWorkerLabels(
|
|
108
|
+
func._step_desired_worker_labels[key] = DesiredWorkerLabels(
|
|
104
109
|
strValue=str(value) if not isinstance(value, int) else None,
|
|
105
110
|
intValue=value if isinstance(value, int) else None,
|
|
106
111
|
required=d["required"] if "required" in d else None,
|
|
@@ -120,8 +125,8 @@ def on_failure_step(
|
|
|
120
125
|
rate_limits: list[RateLimit] | None = None,
|
|
121
126
|
backoff_factor: float | None = None,
|
|
122
127
|
backoff_max_seconds: int | None = None,
|
|
123
|
-
) -> Callable[[
|
|
124
|
-
def inner(func:
|
|
128
|
+
) -> Callable[[WorkflowStepProtocol], WorkflowStepProtocol]:
|
|
129
|
+
def inner(func: WorkflowStepProtocol) -> WorkflowStepProtocol:
|
|
125
130
|
limits = None
|
|
126
131
|
if rate_limits:
|
|
127
132
|
limits = [
|
|
@@ -129,13 +134,12 @@ def on_failure_step(
|
|
|
129
134
|
for rate_limit in rate_limits or []
|
|
130
135
|
]
|
|
131
136
|
|
|
132
|
-
|
|
133
|
-
func.
|
|
134
|
-
func.
|
|
135
|
-
func.
|
|
136
|
-
func.
|
|
137
|
-
func.
|
|
138
|
-
func._on_failure_step_backoff_max_seconds = backoff_max_seconds # type: ignore[attr-defined]
|
|
137
|
+
func._on_failure_step_name = name.lower() or str(func.__name__).lower()
|
|
138
|
+
func._on_failure_step_timeout = timeout
|
|
139
|
+
func._on_failure_step_retries = retries
|
|
140
|
+
func._on_failure_step_rate_limits = limits
|
|
141
|
+
func._on_failure_step_backoff_factor = backoff_factor
|
|
142
|
+
func._on_failure_step_backoff_max_seconds = backoff_max_seconds
|
|
139
143
|
|
|
140
144
|
return func
|
|
141
145
|
|
|
@@ -146,12 +150,11 @@ def concurrency(
|
|
|
146
150
|
name: str = "",
|
|
147
151
|
max_runs: int = 1,
|
|
148
152
|
limit_strategy: ConcurrencyLimitStrategy = ConcurrencyLimitStrategy.CANCEL_IN_PROGRESS,
|
|
149
|
-
) -> Callable[[
|
|
150
|
-
def inner(func:
|
|
151
|
-
|
|
152
|
-
func.
|
|
153
|
-
func.
|
|
154
|
-
func._concurrency_limit_strategy = limit_strategy # type: ignore[attr-defined]
|
|
153
|
+
) -> Callable[[WorkflowStepProtocol], WorkflowStepProtocol]:
|
|
154
|
+
def inner(func: WorkflowStepProtocol) -> WorkflowStepProtocol:
|
|
155
|
+
func._concurrency_fn_name = name.lower() or str(func.__name__).lower()
|
|
156
|
+
func._concurrency_max_runs = max_runs
|
|
157
|
+
func._concurrency_limit_strategy = limit_strategy
|
|
155
158
|
|
|
156
159
|
return func
|
|
157
160
|
|
hatchet_sdk/utils/backoff.py
CHANGED
|
@@ -2,7 +2,7 @@ import asyncio
|
|
|
2
2
|
import random
|
|
3
3
|
|
|
4
4
|
|
|
5
|
-
async def exp_backoff_sleep(attempt: int, max_sleep_time: float = 5):
|
|
5
|
+
async def exp_backoff_sleep(attempt: int, max_sleep_time: float = 5) -> None:
|
|
6
6
|
base_time = 0.1 # starting sleep time in seconds (100 milliseconds)
|
|
7
7
|
jitter = random.uniform(0, base_time) # add random jitter
|
|
8
8
|
sleep_time = min(base_time * (2**attempt) + jitter, max_sleep_time)
|
hatchet_sdk/utils/serialization.py
CHANGED
|
@@ -2,7 +2,10 @@ from typing import Any
|
|
|
2
2
|
|
|
3
3
|
|
|
4
4
|
def flatten(xs: dict[str, Any], parent_key: str, separator: str) -> dict[str, Any]:
|
|
5
|
-
|
|
5
|
+
if not xs:
|
|
6
|
+
return {}
|
|
7
|
+
|
|
8
|
+
items: list[tuple[str, Any]] = []
|
|
6
9
|
|
|
7
10
|
for k, v in xs.items():
|
|
8
11
|
new_key = parent_key + separator + k if parent_key else k
|
hatchet_sdk/utils/tracing.py
CHANGED
|
@@ -6,9 +6,9 @@ from opentelemetry import trace
|
|
|
6
6
|
from opentelemetry.context import Context
|
|
7
7
|
from opentelemetry.exporter.otlp.proto.http.trace_exporter import OTLPSpanExporter
|
|
8
8
|
from opentelemetry.sdk.resources import SERVICE_NAME, Resource
|
|
9
|
-
from opentelemetry.sdk.trace import
|
|
9
|
+
from opentelemetry.sdk.trace import TracerProvider
|
|
10
10
|
from opentelemetry.sdk.trace.export import BatchSpanProcessor
|
|
11
|
-
from opentelemetry.trace import NoOpTracerProvider
|
|
11
|
+
from opentelemetry.trace import NoOpTracerProvider, Tracer
|
|
12
12
|
from opentelemetry.trace.propagation.tracecontext import TraceContextTextMapPropagator
|
|
13
13
|
|
|
14
14
|
from hatchet_sdk.loader import ClientConfig
|
|
@@ -44,7 +44,7 @@ def create_tracer(config: ClientConfig) -> Tracer:
|
|
|
44
44
|
|
|
45
45
|
|
|
46
46
|
def create_carrier() -> dict[str, str]:
|
|
47
|
-
carrier = {}
|
|
47
|
+
carrier: dict[str, str] = {}
|
|
48
48
|
TraceContextTextMapPropagator().inject(carrier)
|
|
49
49
|
|
|
50
50
|
return carrier
|
|
@@ -59,7 +59,10 @@ def inject_carrier_into_metadata(
|
|
|
59
59
|
return metadata
|
|
60
60
|
|
|
61
61
|
|
|
62
|
-
def parse_carrier_from_metadata(metadata: dict[str, Any]) -> Context:
|
|
62
|
+
def parse_carrier_from_metadata(metadata: dict[str, Any] | None) -> Context | None:
|
|
63
|
+
if not metadata:
|
|
64
|
+
return None
|
|
65
|
+
|
|
63
66
|
return (
|
|
64
67
|
TraceContextTextMapPropagator().extract(_ctx)
|
|
65
68
|
if (_ctx := metadata.get(OTEL_CARRIER_KEY))
|
hatchet_sdk/worker/action_listener_process.py
CHANGED
|
@@ -87,15 +87,13 @@ class WorkerActionListenerProcess:
|
|
|
87
87
|
try:
|
|
88
88
|
self.dispatcher_client = new_dispatcher(self.config)
|
|
89
89
|
|
|
90
|
-
self.listener
|
|
91
|
-
|
|
92
|
-
|
|
93
|
-
|
|
94
|
-
|
|
95
|
-
|
|
96
|
-
|
|
97
|
-
_labels=self.labels,
|
|
98
|
-
)
|
|
90
|
+
self.listener = await self.dispatcher_client.get_action_listener(
|
|
91
|
+
GetActionListenerRequest(
|
|
92
|
+
worker_name=self.name,
|
|
93
|
+
services=["default"],
|
|
94
|
+
actions=self.actions,
|
|
95
|
+
max_runs=self.max_runs,
|
|
96
|
+
_labels=self.labels,
|
|
99
97
|
)
|
|
100
98
|
)
|
|
101
99
|
|