hatchet-sdk 0.42.0__py3-none-any.whl → 0.42.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of hatchet-sdk might be problematic.

@@ -54,16 +54,10 @@ class ChildTriggerWorkflowOptions(TypedDict):
     sticky: bool | None = None


-class WorkflowRunDict(TypedDict):
-    workflow_name: str
-    input: Any
-    options: Optional[dict]
-
-
 class ChildWorkflowRunDict(TypedDict):
     workflow_name: str
     input: Any
-    options: ChildTriggerWorkflowOptions[dict]
+    options: ChildTriggerWorkflowOptions
     key: str | None = None


@@ -73,6 +67,12 @@ class TriggerWorkflowOptions(ScheduleTriggerWorkflowOptions, TypedDict):
     namespace: str | None = None


+class WorkflowRunDict(TypedDict):
+    workflow_name: str
+    input: Any
+    options: TriggerWorkflowOptions | None
+
+
 class DedupeViolationErr(Exception):
     """Raised by the Hatchet library to indicate that a workflow has already been run with this deduplication value."""

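For orientation, the net effect of these two hunks is that `WorkflowRunDict` moves below `TriggerWorkflowOptions` and its `options` field is now typed against it. A minimal sketch of a conforming payload (the workflow name and input are illustrative):

```python
from hatchet_sdk.clients.admin import TriggerWorkflowOptions, WorkflowRunDict

# Illustrative bulk-run entry; keys follow the TypedDict defined above.
run: WorkflowRunDict = {
    "workflow_name": "my-workflow",  # hypothetical workflow name
    "input": {"user_id": 42},        # arbitrary JSON-serializable input
    "options": None,                 # or a TriggerWorkflowOptions mapping
}
```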
@@ -260,7 +260,9 @@ class AdminClientAioImpl(AdminClientBase):

     @tenacity_retry
     async def run_workflows(
-        self, workflows: List[WorkflowRunDict], options: TriggerWorkflowOptions = None
+        self,
+        workflows: list[WorkflowRunDict],
+        options: TriggerWorkflowOptions | None = None,
     ) -> List[WorkflowRunRef]:
         if len(workflows) == 0:
             raise ValueError("No workflows to run")
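A usage sketch for the retyped bulk API, assuming the async admin client is reachable as `hatchet.admin.aio` and that `WorkflowRunRef` exposes `workflow_run_id`, as elsewhere in the SDK:

```python
import asyncio

from hatchet_sdk import Hatchet

hatchet = Hatchet()  # assumes the standard HATCHET_CLIENT_* configuration


async def main() -> None:
    # Each entry is a WorkflowRunDict; `options` may now be None per the new signature.
    refs = await hatchet.admin.aio.run_workflows(
        [
            {"workflow_name": "my-workflow", "input": {"n": 1}, "options": None},
            {"workflow_name": "my-workflow", "input": {"n": 2}, "options": None},
        ]
    )
    print([ref.workflow_run_id for ref in refs])


asyncio.run(main())
```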
@@ -61,7 +61,7 @@ class Action:
     worker_id: str
     tenant_id: str
     workflow_run_id: str
-    get_group_key_run_id: Optional[str]
+    get_group_key_run_id: str
     job_id: str
     job_name: str
     job_run_id: str
@@ -137,14 +137,14 @@ class DispatcherClient:

         return response

-    def release_slot(self, step_run_id: str):
+    def release_slot(self, step_run_id: str) -> None:
         self.client.ReleaseSlot(
             ReleaseSlotRequest(stepRunId=step_run_id),
             timeout=DEFAULT_REGISTER_TIMEOUT,
             metadata=get_metadata(self.token),
         )

-    def refresh_timeout(self, step_run_id: str, increment_by: str):
+    def refresh_timeout(self, step_run_id: str, increment_by: str) -> None:
         self.client.RefreshTimeout(
             RefreshTimeoutRequest(
                 stepRunId=step_run_id,
@@ -1,10 +1,15 @@
+from typing import Callable, ParamSpec, TypeVar
+
 import grpc
 import tenacity

 from hatchet_sdk.logger import logger

+P = ParamSpec("P")
+R = TypeVar("R")
+

-def tenacity_retry(func):
+def tenacity_retry(func: Callable[P, R]) -> Callable[P, R]:
     return tenacity.retry(
         reraise=True,
         wait=tenacity.wait_exponential_jitter(),
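This is the standard `ParamSpec` pattern: the decorator now tells type checkers that the wrapped callable keeps its parameter list and return type. A self-contained sketch of the same idea, independent of tenacity:

```python
from typing import Callable, ParamSpec, TypeVar

P = ParamSpec("P")
R = TypeVar("R")


def passthrough(func: Callable[P, R]) -> Callable[P, R]:
    # P.args/P.kwargs bind the wrapper to the exact signature of `func`,
    # so mypy flags bad call sites on the decorated function.
    def wrapper(*args: P.args, **kwargs: P.kwargs) -> R:
        return func(*args, **kwargs)

    return wrapper


@passthrough
def add(a: int, b: int) -> int:
    return a + b


add(1, 2)      # OK, inferred as int
add("1", "2")  # flagged by the type checker (still runs, returning "12")
```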
@@ -2,7 +2,10 @@ import inspect
 import json
 import traceback
 from concurrent.futures import Future, ThreadPoolExecutor
-from typing import List
+from typing import Any, Generic, Type, TypeVar, cast, overload
+from warnings import warn
+
+from pydantic import BaseModel, StrictStr

 from hatchet_sdk.clients.events import EventClient
 from hatchet_sdk.clients.rest.tenacity_utils import tenacity_retry
@@ -10,11 +13,13 @@ from hatchet_sdk.clients.rest_client import RestApi
 from hatchet_sdk.clients.run_event_listener import RunEventListenerClient
 from hatchet_sdk.clients.workflow_listener import PooledWorkflowRunListener
 from hatchet_sdk.context.worker_context import WorkerContext
-from hatchet_sdk.contracts.dispatcher_pb2 import OverridesData
-from hatchet_sdk.contracts.workflows_pb2 import (
+from hatchet_sdk.contracts.dispatcher_pb2 import OverridesData  # type: ignore
+from hatchet_sdk.contracts.workflows_pb2 import (  # type: ignore[attr-defined]
     BulkTriggerWorkflowRequest,
     TriggerWorkflowRequest,
 )
+from hatchet_sdk.utils.types import WorkflowValidator
+from hatchet_sdk.utils.typing import is_basemodel_subclass
 from hatchet_sdk.workflow_run import WorkflowRunRef

 from ..clients.admin import (
@@ -24,25 +29,34 @@ from ..clients.admin import (
     TriggerWorkflowOptions,
     WorkflowRunDict,
 )
-from ..clients.dispatcher.dispatcher import Action, DispatcherClient
+from ..clients.dispatcher.dispatcher import (  # type: ignore[attr-defined]
+    Action,
+    DispatcherClient,
+)
 from ..logger import logger

 DEFAULT_WORKFLOW_POLLING_INTERVAL = 5  # Seconds

+T = TypeVar("T", bound=BaseModel)

-def get_caller_file_path():
+
+def get_caller_file_path() -> str:
     caller_frame = inspect.stack()[2]

     return caller_frame.filename


 class BaseContext:
+
+    action: Action
+    spawn_index: int
+
     def _prepare_workflow_options(
         self,
-        key: str = None,
+        key: str | None = None,
         options: ChildTriggerWorkflowOptions | None = None,
-        worker_id: str = None,
-    ):
+        worker_id: str | None = None,
+    ) -> TriggerWorkflowOptions:
         workflow_run_id = self.action.workflow_run_id
         step_run_id = self.action.step_run_id

@@ -54,7 +68,8 @@ class BaseContext:
         if options is not None and "additional_metadata" in options:
             meta = options["additional_metadata"]

-        trigger_options: TriggerWorkflowOptions = {
+        ## TODO: Pydantic here to simplify this
+        trigger_options: TriggerWorkflowOptions = {  # type: ignore[typeddict-item]
             "parent_id": workflow_run_id,
             "parent_step_run_id": step_run_id,
             "child_key": key,
@@ -95,9 +110,9 @@ class ContextAioImpl(BaseContext):
     async def spawn_workflow(
         self,
         workflow_name: str,
-        input: dict = {},
-        key: str = None,
-        options: ChildTriggerWorkflowOptions = None,
+        input: dict[str, Any] = {},
+        key: str | None = None,
+        options: ChildTriggerWorkflowOptions | None = None,
     ) -> WorkflowRunRef:
         worker_id = self.worker.id()
         # if (
@@ -118,15 +133,15 @@ class ContextAioImpl(BaseContext):

     @tenacity_retry
     async def spawn_workflows(
-        self, child_workflow_runs: List[ChildWorkflowRunDict]
-    ) -> List[WorkflowRunRef]:
+        self, child_workflow_runs: list[ChildWorkflowRunDict]
+    ) -> list[WorkflowRunRef]:

         if len(child_workflow_runs) == 0:
             raise Exception("no child workflows to spawn")

         worker_id = self.worker.id()

-        bulk_trigger_workflow_runs: WorkflowRunDict = []
+        bulk_trigger_workflow_runs: list[WorkflowRunDict] = []
         for child_workflow_run in child_workflow_runs:
             workflow_name = child_workflow_run["workflow_name"]
             input = child_workflow_run["input"]
@@ -134,7 +149,8 @@ class ContextAioImpl(BaseContext):
             key = child_workflow_run.get("key")
             options = child_workflow_run.get("options", {})

-            trigger_options = self._prepare_workflow_options(key, options, worker_id)
+            ## TODO: figure out why this is failing
+            trigger_options = self._prepare_workflow_options(key, options, worker_id)  # type: ignore[arg-type]

             bulk_trigger_workflow_runs.append(
                 WorkflowRunDict(
@@ -161,8 +177,10 @@ class Context(BaseContext):
         workflow_run_event_listener: RunEventListenerClient,
         worker: WorkerContext,
         namespace: str = "",
+        validator_registry: dict[str, WorkflowValidator] = {},
     ):
         self.worker = worker
+        self.validator_registry = validator_registry

         self.aio = ContextAioImpl(
             action,
@@ -179,11 +197,11 @@ class Context(BaseContext):
         # Check the type of action.action_payload before attempting to load it as JSON
         if isinstance(action.action_payload, (str, bytes, bytearray)):
             try:
-                self.data = json.loads(action.action_payload)
+                self.data = cast(dict[str, Any], json.loads(action.action_payload))
             except Exception as e:
                 logger.error(f"Error parsing action payload: {e}")
                 # Assign an empty dictionary if parsing fails
-                self.data = {}
+                self.data: dict[str, Any] = {}  # type: ignore[no-redef]
         else:
             # Directly assign the payload to self.data if it's already a dict
             self.data = (
@@ -191,6 +209,7 @@ class Context(BaseContext):
         )

         self.action = action
+
         # FIXME: stepRunId is a legacy field, we should remove it
         self.stepRunId = action.step_run_id

@@ -218,33 +237,53 @@ class Context(BaseContext):
         else:
             self.input = self.data.get("input", {})

-    def step_output(self, step: str):
+    def step_output(self, step: str) -> dict[str, Any] | BaseModel:
+        validators = self.validator_registry.get(step)
+
         try:
-            return self.data["parents"][step]
+            parent_step_data = cast(dict[str, Any], self.data["parents"][step])
         except KeyError:
             raise ValueError(f"Step output for '{step}' not found")

+        if validators and (v := validators.step_output):
+            return v.model_validate(parent_step_data)
+
+        return parent_step_data
+
     def triggered_by_event(self) -> bool:
-        return self.data.get("triggered_by", "") == "event"
+        return cast(str, self.data.get("triggered_by", "")) == "event"
+
+    def workflow_input(self) -> dict[str, Any] | T:
+        if (r := self.validator_registry.get(self.action.action_id)) and (
+            i := r.workflow_input
+        ):
+            return cast(
+                T,
+                i.model_validate(self.input),
+            )

-    def workflow_input(self):
         return self.input

-    def workflow_run_id(self):
+    def workflow_run_id(self) -> str:
         return self.action.workflow_run_id

-    def cancel(self):
+    def cancel(self) -> None:
         logger.debug("cancelling step...")
         self.exit_flag = True

     # done returns true if the context has been cancelled
-    def done(self):
+    def done(self) -> bool:
         return self.exit_flag

-    def playground(self, name: str, default: str = None):
+    def playground(self, name: str, default: str | None = None) -> str | None:
         # if the key exists in the overrides_data field, return the value
         if name in self.overrides_data:
-            return self.overrides_data[name]
+            warn(
+                "Use of `overrides_data` is deprecated.",
+                DeprecationWarning,
+                stacklevel=1,
+            )
+            return str(self.overrides_data[name])

         caller_file = get_caller_file_path()

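Taken together with the new `validator_registry`, `workflow_input()` and `step_output()` now return validated Pydantic models when a model is registered for the action, and fall back to raw dicts otherwise. A hedged sketch of what a step body can now assume (model and step names are illustrative):

```python
from pydantic import BaseModel

from hatchet_sdk import Context


class MyInput(BaseModel):  # hypothetical input model
    user_id: int


def my_step(self, context: Context) -> dict:
    data = context.workflow_input()
    # With input_validator=MyInput registered, this is a validated MyInput;
    # without a registered validator it remains a plain dict.
    if isinstance(data, MyInput):
        return {"user_id": data.user_id}
    return dict(data)
```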
@@ -259,7 +298,7 @@ class Context(BaseContext):

         return default

-    def _log(self, line: str) -> (bool, Exception):  # type: ignore
+    def _log(self, line: str) -> tuple[bool, Exception | None]:
         try:
             self.event_client.log(message=line, step_run_id=self.stepRunId)
             return True, None
@@ -267,7 +306,7 @@ class Context(BaseContext):
             # we don't want to raise an exception here, as it will kill the log thread
             return False, e

-    def log(self, line, raise_on_error: bool = False):
+    def log(self, line: Any, raise_on_error: bool = False) -> None:
         if self.stepRunId == "":
             return

@@ -277,9 +316,9 @@ class Context(BaseContext):
         except Exception:
             line = str(line)

-        future: Future = self.logger_thread_pool.submit(self._log, line)
+        future = self.logger_thread_pool.submit(self._log, line)

-        def handle_result(future: Future):
+        def handle_result(future: Future[tuple[bool, Exception | None]]) -> None:
             success, exception = future.result()
             if not success and exception:
                 if raise_on_error:
@@ -297,22 +336,22 @@ class Context(BaseContext):

         future.add_done_callback(handle_result)

-    def release_slot(self):
+    def release_slot(self) -> None:
         return self.dispatcher_client.release_slot(self.stepRunId)

-    def _put_stream(self, data: str | bytes):
+    def _put_stream(self, data: str | bytes) -> None:
         try:
             self.event_client.stream(data=data, step_run_id=self.stepRunId)
         except Exception as e:
             logger.error(f"Error putting stream event: {e}")

-    def put_stream(self, data: str | bytes):
+    def put_stream(self, data: str | bytes) -> None:
         if self.stepRunId == "":
             return

         self.stream_event_thread_pool.submit(self._put_stream, data)

-    def refresh_timeout(self, increment_by: str):
+    def refresh_timeout(self, increment_by: str) -> None:
         try:
             return self.dispatcher_client.refresh_timeout(
                 step_run_id=self.stepRunId, increment_by=increment_by
@@ -320,28 +359,28 @@ class Context(BaseContext):
         except Exception as e:
             logger.error(f"Error refreshing timeout: {e}")

-    def retry_count(self):
+    def retry_count(self) -> int:
         return self.action.retry_count

-    def additional_metadata(self):
+    def additional_metadata(self) -> dict[str, Any] | None:
         return self.action.additional_metadata

-    def child_index(self):
+    def child_index(self) -> int | None:
         return self.action.child_workflow_index

-    def child_key(self):
+    def child_key(self) -> str | None:
         return self.action.child_workflow_key

-    def parent_workflow_run_id(self):
+    def parent_workflow_run_id(self) -> str | None:
         return self.action.parent_workflow_run_id

-    def fetch_run_failures(self):
+    def fetch_run_failures(self) -> list[dict[str, StrictStr]]:
         data = self.rest_client.workflow_run_get(self.action.workflow_run_id)
         other_job_runs = [
-            run for run in data.job_runs if run.job_id != self.action.job_id
+            run for run in (data.job_runs or []) if run.job_id != self.action.job_id
         ]
         # TODO: Parse Step Runs using a Pydantic Model rather than a hand crafted dictionary
-        failed_step_runs = [
+        return [
             {
                 "step_id": step_run.step_id,
                 "step_run_action_name": step_run.step.action,
@@ -350,7 +389,5 @@ class Context(BaseContext):
             for job_run in other_job_runs
             if job_run.step_runs
             for step_run in job_run.step_runs
-            if step_run.error
+            if step_run.error and step_run.step
         ]
-
-        return failed_step_runs
@@ -21,7 +21,7 @@ class WorkerContext:
         await self.client.async_upsert_worker_labels(self._worker_id, labels)
         self._labels.update(labels)

-    def id(self):
+    def id(self) -> str:
         return self._worker_id

     # def has_workflow(self, workflow_name: str):
hatchet_sdk/hatchet.py CHANGED
@@ -1,7 +1,8 @@
 import asyncio
 import logging
-from typing import Any, Callable, Optional, ParamSpec, TypeVar
+from typing import Any, Callable, Optional, Type, TypeVar, cast, get_type_hints

+from pydantic import BaseModel
 from typing_extensions import deprecated

 from hatchet_sdk.clients.rest_client import RestApi
@@ -27,14 +28,17 @@ from .clients.events import EventClient
 from .clients.run_event_listener import RunEventListenerClient
 from .logger import logger
 from .worker.worker import Worker
-from .workflow import ConcurrencyExpression, WorkflowMeta
+from .workflow import (
+    ConcurrencyExpression,
+    WorkflowInterface,
+    WorkflowMeta,
+    WorkflowStepProtocol,
+)

-P = ParamSpec("P")
-R = TypeVar("R")
+T = TypeVar("T", bound=BaseModel)


-## TODO: Fix return type here to properly type hint the metaclass
-def workflow(  # type: ignore[no-untyped-def]
+def workflow(
     name: str = "",
     on_events: list[str] | None = None,
     on_crons: list[str] | None = None,
@@ -44,11 +48,12 @@ def workflow(  # type: ignore[no-untyped-def]
     sticky: StickyStrategy = None,
     default_priority: int | None = None,
     concurrency: ConcurrencyExpression | None = None,
-):
+    input_validator: Type[T] | None = None,
+) -> Callable[[Type[WorkflowInterface]], WorkflowMeta]:
     on_events = on_events or []
     on_crons = on_crons or []

-    def inner(cls: Any) -> WorkflowMeta:
+    def inner(cls: Type[WorkflowInterface]) -> WorkflowMeta:
         cls.on_events = on_events
         cls.on_crons = on_crons
         cls.name = name or str(cls.__name__)
@@ -62,7 +67,8 @@ def workflow(  # type: ignore[no-untyped-def]
         # with WorkflowMeta as its metaclass

         ## TODO: Figure out how to type this metaclass correctly
-        return WorkflowMeta(cls.name, cls.__bases__, dict(cls.__dict__))  # type: ignore[no-untyped-call]
+        cls.input_validator = input_validator
+        return WorkflowMeta(cls.name, cls.__bases__, dict(cls.__dict__))

     return inner

@@ -76,10 +82,10 @@ def step(
     desired_worker_labels: dict[str, DesiredWorkerLabel] = {},
     backoff_factor: float | None = None,
     backoff_max_seconds: int | None = None,
-) -> Callable[[Callable[P, R]], Callable[P, R]]:
+) -> Callable[[WorkflowStepProtocol], WorkflowStepProtocol]:
     parents = parents or []

-    def inner(func: Callable[P, R]) -> Callable[P, R]:
+    def inner(func: WorkflowStepProtocol) -> WorkflowStepProtocol:
         limits = None
         if rate_limits:
             limits = [
@@ -87,20 +93,19 @@ def step(
                 for rate_limit in rate_limits or []
             ]

-        ## TODO: Use Protocol here to help with MyPy errors
-        func._step_name = name.lower() or str(func.__name__).lower()  # type: ignore[attr-defined]
-        func._step_parents = parents  # type: ignore[attr-defined]
-        func._step_timeout = timeout  # type: ignore[attr-defined]
-        func._step_retries = retries  # type: ignore[attr-defined]
-        func._step_rate_limits = limits  # type: ignore[attr-defined]
-        func._step_backoff_factor = backoff_factor  # type: ignore[attr-defined]
-        func._step_backoff_max_seconds = backoff_max_seconds  # type: ignore[attr-defined]
+        func._step_name = name.lower() or str(func.__name__).lower()
+        func._step_parents = parents
+        func._step_timeout = timeout
+        func._step_retries = retries
+        func._step_rate_limits = limits
+        func._step_backoff_factor = backoff_factor
+        func._step_backoff_max_seconds = backoff_max_seconds

-        func._step_desired_worker_labels = {}  # type: ignore[attr-defined]
+        func._step_desired_worker_labels = {}

         for key, d in desired_worker_labels.items():
             value = d["value"] if "value" in d else None
-            func._step_desired_worker_labels[key] = DesiredWorkerLabels(  # type: ignore[attr-defined]
+            func._step_desired_worker_labels[key] = DesiredWorkerLabels(
                 strValue=str(value) if not isinstance(value, int) else None,
                 intValue=value if isinstance(value, int) else None,
                 required=d["required"] if "required" in d else None,
@@ -120,8 +125,8 @@ def on_failure_step(
     rate_limits: list[RateLimit] | None = None,
     backoff_factor: float | None = None,
     backoff_max_seconds: int | None = None,
-) -> Callable[[Callable[P, R]], Callable[P, R]]:
-    def inner(func: Callable[P, R]) -> Callable[P, R]:
+) -> Callable[[WorkflowStepProtocol], WorkflowStepProtocol]:
+    def inner(func: WorkflowStepProtocol) -> WorkflowStepProtocol:
         limits = None
         if rate_limits:
             limits = [
@@ -129,13 +134,12 @@ def on_failure_step(
                 for rate_limit in rate_limits or []
             ]

-        ## TODO: Use Protocol here to help with MyPy errors
-        func._on_failure_step_name = name.lower() or str(func.__name__).lower()  # type: ignore[attr-defined]
-        func._on_failure_step_timeout = timeout  # type: ignore[attr-defined]
-        func._on_failure_step_retries = retries  # type: ignore[attr-defined]
-        func._on_failure_step_rate_limits = limits  # type: ignore[attr-defined]
-        func._on_failure_step_backoff_factor = backoff_factor  # type: ignore[attr-defined]
-        func._on_failure_step_backoff_max_seconds = backoff_max_seconds  # type: ignore[attr-defined]
+        func._on_failure_step_name = name.lower() or str(func.__name__).lower()
+        func._on_failure_step_timeout = timeout
+        func._on_failure_step_retries = retries
+        func._on_failure_step_rate_limits = limits
+        func._on_failure_step_backoff_factor = backoff_factor
+        func._on_failure_step_backoff_max_seconds = backoff_max_seconds

         return func

@@ -146,12 +150,11 @@ def concurrency(
     name: str = "",
     max_runs: int = 1,
     limit_strategy: ConcurrencyLimitStrategy = ConcurrencyLimitStrategy.CANCEL_IN_PROGRESS,
-) -> Callable[[Callable[P, R]], Callable[P, R]]:
-    def inner(func: Callable[P, R]) -> Callable[P, R]:
-        ## TODO: Use Protocol here to help with MyPy errors
-        func._concurrency_fn_name = name.lower() or str(func.__name__).lower()  # type: ignore[attr-defined]
-        func._concurrency_max_runs = max_runs  # type: ignore[attr-defined]
-        func._concurrency_limit_strategy = limit_strategy  # type: ignore[attr-defined]
+) -> Callable[[WorkflowStepProtocol], WorkflowStepProtocol]:
+    def inner(func: WorkflowStepProtocol) -> WorkflowStepProtocol:
+        func._concurrency_fn_name = name.lower() or str(func.__name__).lower()
+        func._concurrency_max_runs = max_runs
+        func._concurrency_limit_strategy = limit_strategy

         return func

@@ -2,7 +2,7 @@ import asyncio
 import random


-async def exp_backoff_sleep(attempt: int, max_sleep_time: float = 5):
+async def exp_backoff_sleep(attempt: int, max_sleep_time: float = 5) -> None:
     base_time = 0.1  # starting sleep time in seconds (100 milliseconds)
     jitter = random.uniform(0, base_time)  # add random jitter
     sleep_time = min(base_time * (2**attempt) + jitter, max_sleep_time)
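For reference, the schedule this helper produces is min(0.1 · 2^attempt + jitter, max_sleep_time): roughly 0.1 s at attempt 0, about 3.2 s at attempt 5, and capped at the 5 s default from attempt 6 onward (0.1 · 2⁶ = 6.4 s, clamped).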
@@ -2,7 +2,10 @@ from typing import Any


 def flatten(xs: dict[str, Any], parent_key: str, separator: str) -> dict[str, Any]:
-    items = []
+    if not xs:
+        return {}
+
+    items: list[tuple[str, Any]] = []

     for k, v in xs.items():
         new_key = parent_key + separator + k if parent_key else k
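Behavior sketch for the guarded `flatten`. The recursive branch lives outside this hunk, so the nested-dict result below is the conventional expectation for a flatten utility, not something the diff itself shows:

```python
# Assumed behavior under the new early return:
flatten({}, parent_key="", separator=".")  # -> {} instead of iterating an empty dict
flatten({"a": {"b": 1}}, "", ".")          # -> {"a.b": 1} (conventional flatten result)
```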
@@ -6,9 +6,9 @@ from opentelemetry import trace
 from opentelemetry.context import Context
 from opentelemetry.exporter.otlp.proto.http.trace_exporter import OTLPSpanExporter
 from opentelemetry.sdk.resources import SERVICE_NAME, Resource
-from opentelemetry.sdk.trace import Tracer, TracerProvider
+from opentelemetry.sdk.trace import TracerProvider
 from opentelemetry.sdk.trace.export import BatchSpanProcessor
-from opentelemetry.trace import NoOpTracerProvider
+from opentelemetry.trace import NoOpTracerProvider, Tracer
 from opentelemetry.trace.propagation.tracecontext import TraceContextTextMapPropagator

 from hatchet_sdk.loader import ClientConfig
@@ -44,7 +44,7 @@ def create_tracer(config: ClientConfig) -> Tracer:


 def create_carrier() -> dict[str, str]:
-    carrier = {}
+    carrier: dict[str, str] = {}
     TraceContextTextMapPropagator().inject(carrier)

     return carrier
@@ -59,7 +59,10 @@ def inject_carrier_into_metadata(
     return metadata


-def parse_carrier_from_metadata(metadata: dict[str, Any]) -> Context:
+def parse_carrier_from_metadata(metadata: dict[str, Any] | None) -> Context | None:
+    if not metadata:
+        return None
+
     return (
         TraceContextTextMapPropagator().extract(_ctx)
         if (_ctx := metadata.get(OTEL_CARRIER_KEY))
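With the widened signature, a worker can pass possibly-absent run metadata straight through; `None` now short-circuits instead of crashing on `.get`. A round-trip sketch using the same propagator these helpers rely on:

```python
from opentelemetry.trace.propagation.tracecontext import TraceContextTextMapPropagator

# Producer side: create_carrier() injects the active span context into a dict.
carrier: dict[str, str] = {}
TraceContextTextMapPropagator().inject(carrier)

# Consumer side: extract() rebuilds an OTel Context from that mapping,
# which is what parse_carrier_from_metadata does when the carrier key is present.
ctx = TraceContextTextMapPropagator().extract(carrier)
```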
@@ -0,0 +1,8 @@
+from typing import Type
+
+from pydantic import BaseModel
+
+
+class WorkflowValidator(BaseModel):
+    workflow_input: Type[BaseModel] | None = None
+    step_output: Type[BaseModel] | None = None
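This new model is what `Context.validator_registry` stores per action. A construction sketch (the models are illustrative):

```python
from pydantic import BaseModel

from hatchet_sdk.utils.types import WorkflowValidator


class Input(BaseModel):
    user_id: int


class Output(BaseModel):
    ok: bool


# Both fields default to None; the registry consults whichever is set.
validator = WorkflowValidator(workflow_input=Input, step_output=Output)
```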
@@ -0,0 +1,9 @@
+from typing import Any, Type, TypeGuard, TypeVar
+
+from pydantic import BaseModel
+
+T = TypeVar("T", bound=BaseModel)
+
+
+def is_basemodel_subclass(model: Any) -> bool:
+    return isinstance(model, type) and issubclass(model, BaseModel)
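Note that `Type`, `TypeGuard`, and `T` are imported here but unused; the shipped helper returns a plain `bool`. A narrowing variant would look like the sketch below (an assumption about intent, not the released API):

```python
from typing import Any, Type, TypeGuard

from pydantic import BaseModel


def is_basemodel_subclass_narrowing(model: Any) -> TypeGuard[Type[BaseModel]]:
    # Same runtime check, but type checkers narrow `model` to Type[BaseModel]
    # wherever this returns True.
    return isinstance(model, type) and issubclass(model, BaseModel)
```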
@@ -87,15 +87,13 @@ class WorkerActionListenerProcess:
         try:
             self.dispatcher_client = new_dispatcher(self.config)

-            self.listener: ActionListener = (
-                await self.dispatcher_client.get_action_listener(
-                    GetActionListenerRequest(
-                        worker_name=self.name,
-                        services=["default"],
-                        actions=self.actions,
-                        max_runs=self.max_runs,
-                        _labels=self.labels,
-                    )
+            self.listener = await self.dispatcher_client.get_action_listener(
+                GetActionListenerRequest(
+                    worker_name=self.name,
+                    services=["default"],
+                    actions=self.actions,
+                    max_runs=self.max_runs,
+                    _labels=self.labels,
                 )
             )