hatchet-sdk 1.0.0__py3-none-any.whl → 1.0.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of hatchet-sdk has been flagged as potentially problematic (the original registry page links to further details).

Files changed (73):
  1. hatchet_sdk/__init__.py +32 -16
  2. hatchet_sdk/client.py +25 -63
  3. hatchet_sdk/clients/admin.py +203 -142
  4. hatchet_sdk/clients/dispatcher/action_listener.py +42 -42
  5. hatchet_sdk/clients/dispatcher/dispatcher.py +18 -16
  6. hatchet_sdk/clients/durable_event_listener.py +327 -0
  7. hatchet_sdk/clients/rest/__init__.py +12 -1
  8. hatchet_sdk/clients/rest/api/log_api.py +258 -0
  9. hatchet_sdk/clients/rest/api/task_api.py +32 -6
  10. hatchet_sdk/clients/rest/api/workflow_runs_api.py +626 -0
  11. hatchet_sdk/clients/rest/models/__init__.py +12 -1
  12. hatchet_sdk/clients/rest/models/v1_log_line.py +94 -0
  13. hatchet_sdk/clients/rest/models/v1_log_line_level.py +39 -0
  14. hatchet_sdk/clients/rest/models/v1_log_line_list.py +110 -0
  15. hatchet_sdk/clients/rest/models/v1_task_summary.py +80 -64
  16. hatchet_sdk/clients/rest/models/v1_trigger_workflow_run_request.py +95 -0
  17. hatchet_sdk/clients/rest/models/v1_workflow_run_display_name.py +98 -0
  18. hatchet_sdk/clients/rest/models/v1_workflow_run_display_name_list.py +114 -0
  19. hatchet_sdk/clients/rest/models/workflow_run_shape_item_for_workflow_run_details.py +9 -4
  20. hatchet_sdk/clients/rest/models/workflow_runs_metrics.py +5 -1
  21. hatchet_sdk/clients/run_event_listener.py +0 -1
  22. hatchet_sdk/clients/v1/api_client.py +81 -0
  23. hatchet_sdk/context/context.py +86 -159
  24. hatchet_sdk/contracts/dispatcher_pb2_grpc.py +1 -1
  25. hatchet_sdk/contracts/events_pb2.py +2 -2
  26. hatchet_sdk/contracts/events_pb2_grpc.py +1 -1
  27. hatchet_sdk/contracts/v1/dispatcher_pb2.py +36 -0
  28. hatchet_sdk/contracts/v1/dispatcher_pb2.pyi +38 -0
  29. hatchet_sdk/contracts/v1/dispatcher_pb2_grpc.py +145 -0
  30. hatchet_sdk/contracts/v1/shared/condition_pb2.py +39 -0
  31. hatchet_sdk/contracts/v1/shared/condition_pb2.pyi +72 -0
  32. hatchet_sdk/contracts/v1/shared/condition_pb2_grpc.py +29 -0
  33. hatchet_sdk/contracts/v1/workflows_pb2.py +67 -0
  34. hatchet_sdk/contracts/v1/workflows_pb2.pyi +228 -0
  35. hatchet_sdk/contracts/v1/workflows_pb2_grpc.py +234 -0
  36. hatchet_sdk/contracts/workflows_pb2_grpc.py +1 -1
  37. hatchet_sdk/features/cron.py +91 -121
  38. hatchet_sdk/features/logs.py +16 -0
  39. hatchet_sdk/features/metrics.py +75 -0
  40. hatchet_sdk/features/rate_limits.py +45 -0
  41. hatchet_sdk/features/runs.py +221 -0
  42. hatchet_sdk/features/scheduled.py +114 -131
  43. hatchet_sdk/features/workers.py +41 -0
  44. hatchet_sdk/features/workflows.py +55 -0
  45. hatchet_sdk/hatchet.py +463 -165
  46. hatchet_sdk/opentelemetry/instrumentor.py +8 -13
  47. hatchet_sdk/rate_limit.py +33 -39
  48. hatchet_sdk/runnables/contextvars.py +12 -0
  49. hatchet_sdk/runnables/standalone.py +192 -0
  50. hatchet_sdk/runnables/task.py +144 -0
  51. hatchet_sdk/runnables/types.py +138 -0
  52. hatchet_sdk/runnables/workflow.py +771 -0
  53. hatchet_sdk/utils/aio_utils.py +0 -79
  54. hatchet_sdk/utils/proto_enums.py +0 -7
  55. hatchet_sdk/utils/timedelta_to_expression.py +23 -0
  56. hatchet_sdk/utils/typing.py +2 -2
  57. hatchet_sdk/v0/clients/rest_client.py +9 -0
  58. hatchet_sdk/v0/worker/action_listener_process.py +18 -2
  59. hatchet_sdk/waits.py +120 -0
  60. hatchet_sdk/worker/action_listener_process.py +64 -30
  61. hatchet_sdk/worker/runner/run_loop_manager.py +35 -26
  62. hatchet_sdk/worker/runner/runner.py +72 -55
  63. hatchet_sdk/worker/runner/utils/capture_logs.py +3 -11
  64. hatchet_sdk/worker/worker.py +155 -118
  65. hatchet_sdk/workflow_run.py +4 -5
  66. {hatchet_sdk-1.0.0.dist-info → hatchet_sdk-1.0.1.dist-info}/METADATA +1 -2
  67. {hatchet_sdk-1.0.0.dist-info → hatchet_sdk-1.0.1.dist-info}/RECORD +69 -43
  68. {hatchet_sdk-1.0.0.dist-info → hatchet_sdk-1.0.1.dist-info}/entry_points.txt +2 -0
  69. hatchet_sdk/clients/rest_client.py +0 -636
  70. hatchet_sdk/semver.py +0 -30
  71. hatchet_sdk/worker/runner/utils/error_with_traceback.py +0 -6
  72. hatchet_sdk/workflow.py +0 -527
  73. {hatchet_sdk-1.0.0.dist-info → hatchet_sdk-1.0.1.dist-info}/WHEEL +0 -0
File: hatchet_sdk/clients/rest/models/v1_workflow_run_display_name_list.py (new file)
@@ -0,0 +1,114 @@
1
+ # coding: utf-8
2
+
3
+ """
4
+ Hatchet API
5
+
6
+ The Hatchet API
7
+
8
+ The version of the OpenAPI document: 1.0.0
9
+ Generated by OpenAPI Generator (https://openapi-generator.tech)
10
+
11
+ Do not edit the class manually.
12
+ """ # noqa: E501
13
+
14
+
15
+ from __future__ import annotations
16
+
17
+ import json
18
+ import pprint
19
+ import re # noqa: F401
20
+ from typing import Any, ClassVar, Dict, List, Optional, Set
21
+
22
+ from pydantic import BaseModel, ConfigDict, Field
23
+ from typing_extensions import Self
24
+
25
+ from hatchet_sdk.clients.rest.models.pagination_response import PaginationResponse
26
+ from hatchet_sdk.clients.rest.models.v1_workflow_run_display_name import (
27
+ V1WorkflowRunDisplayName,
28
+ )
29
+
30
+
31
+ class V1WorkflowRunDisplayNameList(BaseModel):
32
+ """
33
+ V1WorkflowRunDisplayNameList
34
+ """ # noqa: E501
35
+
36
+ pagination: PaginationResponse
37
+ rows: List[V1WorkflowRunDisplayName] = Field(
38
+ description="The list of display names"
39
+ )
40
+ __properties: ClassVar[List[str]] = ["pagination", "rows"]
41
+
42
+ model_config = ConfigDict(
43
+ populate_by_name=True,
44
+ validate_assignment=True,
45
+ protected_namespaces=(),
46
+ )
47
+
48
+ def to_str(self) -> str:
49
+ """Returns the string representation of the model using alias"""
50
+ return pprint.pformat(self.model_dump(by_alias=True))
51
+
52
+ def to_json(self) -> str:
53
+ """Returns the JSON representation of the model using alias"""
54
+ # TODO: pydantic v2: use .model_dump_json(by_alias=True, exclude_unset=True) instead
55
+ return json.dumps(self.to_dict())
56
+
57
+ @classmethod
58
+ def from_json(cls, json_str: str) -> Optional[Self]:
59
+ """Create an instance of V1WorkflowRunDisplayNameList from a JSON string"""
60
+ return cls.from_dict(json.loads(json_str))
61
+
62
+ def to_dict(self) -> Dict[str, Any]:
63
+ """Return the dictionary representation of the model using alias.
64
+
65
+ This has the following differences from calling pydantic's
66
+ `self.model_dump(by_alias=True)`:
67
+
68
+ * `None` is only added to the output dict for nullable fields that
69
+ were set at model initialization. Other fields with value `None`
70
+ are ignored.
71
+ """
72
+ excluded_fields: Set[str] = set([])
73
+
74
+ _dict = self.model_dump(
75
+ by_alias=True,
76
+ exclude=excluded_fields,
77
+ exclude_none=True,
78
+ )
79
+ # override the default output from pydantic by calling `to_dict()` of pagination
80
+ if self.pagination:
81
+ _dict["pagination"] = self.pagination.to_dict()
82
+ # override the default output from pydantic by calling `to_dict()` of each item in rows (list)
83
+ _items = []
84
+ if self.rows:
85
+ for _item_rows in self.rows:
86
+ if _item_rows:
87
+ _items.append(_item_rows.to_dict())
88
+ _dict["rows"] = _items
89
+ return _dict
90
+
91
+ @classmethod
92
+ def from_dict(cls, obj: Optional[Dict[str, Any]]) -> Optional[Self]:
93
+ """Create an instance of V1WorkflowRunDisplayNameList from a dict"""
94
+ if obj is None:
95
+ return None
96
+
97
+ if not isinstance(obj, dict):
98
+ return cls.model_validate(obj)
99
+
100
+ _obj = cls.model_validate(
101
+ {
102
+ "pagination": (
103
+ PaginationResponse.from_dict(obj["pagination"])
104
+ if obj.get("pagination") is not None
105
+ else None
106
+ ),
107
+ "rows": (
108
+ [V1WorkflowRunDisplayName.from_dict(_item) for _item in obj["rows"]]
109
+ if obj.get("rows") is not None
110
+ else None
111
+ ),
112
+ }
113
+ )
114
+ return _obj
File: hatchet_sdk/clients/rest/models/workflow_run_shape_item_for_workflow_run_details.py
@@ -31,13 +31,17 @@ class WorkflowRunShapeItemForWorkflowRunDetails(BaseModel):
31
31
  task_external_id: Annotated[
32
32
  str, Field(min_length=36, strict=True, max_length=36)
33
33
  ] = Field(alias="taskExternalId")
34
- children_external_ids: List[
34
+ step_id: Annotated[str, Field(min_length=36, strict=True, max_length=36)] = Field(
35
+ alias="stepId"
36
+ )
37
+ children_step_ids: List[
35
38
  Annotated[str, Field(min_length=36, strict=True, max_length=36)]
36
- ] = Field(alias="childrenExternalIds")
39
+ ] = Field(alias="childrenStepIds")
37
40
  task_name: StrictStr = Field(alias="taskName")
38
41
  __properties: ClassVar[List[str]] = [
39
42
  "taskExternalId",
40
- "childrenExternalIds",
43
+ "stepId",
44
+ "childrenStepIds",
41
45
  "taskName",
42
46
  ]
43
47
 
@@ -92,7 +96,8 @@ class WorkflowRunShapeItemForWorkflowRunDetails(BaseModel):
92
96
  _obj = cls.model_validate(
93
97
  {
94
98
  "taskExternalId": obj.get("taskExternalId"),
95
- "childrenExternalIds": obj.get("childrenExternalIds"),
99
+ "stepId": obj.get("stepId"),
100
+ "childrenStepIds": obj.get("childrenStepIds"),
96
101
  "taskName": obj.get("taskName"),
97
102
  }
98
103
  )
File: hatchet_sdk/clients/rest/models/workflow_runs_metrics.py
@@ -22,13 +22,17 @@ from typing import Any, ClassVar, Dict, List, Optional, Set
22
22
  from pydantic import BaseModel, ConfigDict
23
23
  from typing_extensions import Self
24
24
 
25
+ from hatchet_sdk.clients.rest.models.workflow_runs_metrics_counts import (
26
+ WorkflowRunsMetricsCounts,
27
+ )
28
+
25
29
 
26
30
  class WorkflowRunsMetrics(BaseModel):
27
31
  """
28
32
  WorkflowRunsMetrics
29
33
  """ # noqa: E501
30
34
 
31
- counts: Optional[Dict[str, Any]] = None
35
+ counts: Optional[WorkflowRunsMetricsCounts] = None
32
36
  __properties: ClassVar[List[str]] = ["counts"]
33
37
 
34
38
  model_config = ConfigDict(
File: hatchet_sdk/clients/run_event_listener.py
@@ -132,7 +132,6 @@ class RunEventListener:
132
132
 
133
133
  try:
134
134
  if workflow_event.eventPayload:
135
- ## TODO: Should this be `dumps` instead?
136
135
  payload = json.loads(workflow_event.eventPayload)
137
136
  except Exception:
138
137
  payload = workflow_event.eventPayload
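File: hatchet_sdk/clients/v1/api_client.py (new file)
Review note: `_run_async_from_sync` calls `loop.run_until_complete` in the branch where the current event loop is already running. asyncio raises `RuntimeError` in that case, so the two branches appear to be swapped.
@@ -0,0 +1,81 @@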
1
+ import asyncio
2
+ from concurrent.futures import ThreadPoolExecutor
3
+ from typing import AsyncContextManager, Callable, Coroutine, ParamSpec, TypeVar
4
+
5
+ from hatchet_sdk.clients.rest.api_client import ApiClient
6
+ from hatchet_sdk.clients.rest.configuration import Configuration
7
+ from hatchet_sdk.config import ClientConfig
8
+ from hatchet_sdk.utils.typing import JSONSerializableMapping
9
+
10
+ ## Type variables to use with coroutines.
11
+ ## See https://stackoverflow.com/questions/73240620/the-right-way-to-type-hint-a-coroutine-function
12
+ ## Return type
13
+ R = TypeVar("R")
14
+
15
+ ## Yield type
16
+ Y = TypeVar("Y")
17
+
18
+ ## Send type
19
+ S = TypeVar("S")
20
+
21
+ P = ParamSpec("P")
22
+
23
+
24
+ def maybe_additional_metadata_to_kv(
25
+ additional_metadata: dict[str, str] | JSONSerializableMapping | None
26
+ ) -> list[str] | None:
27
+ if not additional_metadata:
28
+ return None
29
+
30
+ return [f"{k}:{v}" for k, v in additional_metadata.items()]
31
+
32
+
33
+ class BaseRestClient:
34
+ def __init__(self, config: ClientConfig) -> None:
35
+ self.tenant_id = config.tenant_id
36
+
37
+ self.client_config = config
38
+ self.api_config = Configuration(
39
+ host=config.server_url,
40
+ access_token=config.token,
41
+ )
42
+
43
+ self.api_config.datetime_format = "%Y-%m-%dT%H:%M:%S.%fZ"
44
+
45
+ def client(self) -> AsyncContextManager[ApiClient]:
46
+ return ApiClient(self.api_config)
47
+
48
+ def _run_async_function_do_not_use_directly(
49
+ self,
50
+ async_func: Callable[P, Coroutine[Y, S, R]],
51
+ *args: P.args,
52
+ **kwargs: P.kwargs,
53
+ ) -> R:
54
+ loop = asyncio.new_event_loop()
55
+ asyncio.set_event_loop(loop)
56
+ try:
57
+ return loop.run_until_complete(async_func(*args, **kwargs))
58
+ finally:
59
+ loop.close()
60
+
61
+ def _run_async_from_sync(
62
+ self,
63
+ async_func: Callable[P, Coroutine[Y, S, R]],
64
+ *args: P.args,
65
+ **kwargs: P.kwargs,
66
+ ) -> R:
67
+ try:
68
+ loop = asyncio.get_event_loop()
69
+ except RuntimeError:
70
+ loop = None
71
+
72
+ if loop and loop.is_running():
73
+ return loop.run_until_complete(async_func(*args, **kwargs))
74
+ else:
75
+ with ThreadPoolExecutor() as executor:
76
+ future = executor.submit(
77
+ lambda: self._run_async_function_do_not_use_directly(
78
+ async_func, *args, **kwargs
79
+ )
80
+ )
81
+ return future.result()
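File: hatchet_sdk/context/context.py
Review note: the added `raise ValueError("{task.name} was skipped")` in `task_output` lacks the `f` prefix. As a result, the message contains the literal text `{task.name}` rather than the task's name.
@@ -2,31 +2,31 @@ import inspect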
2
2
  import json
3
3
  import traceback
4
4
  from concurrent.futures import Future, ThreadPoolExecutor
5
- from typing import Any, cast
5
+ from datetime import timedelta
6
+ from typing import TYPE_CHECKING, Any, cast
6
7
 
7
8
  from pydantic import BaseModel
8
9
 
9
- from hatchet_sdk.clients.admin import (
10
- AdminClient,
11
- ChildTriggerWorkflowOptions,
12
- ChildWorkflowRunDict,
13
- TriggerWorkflowOptions,
14
- WorkflowRunDict,
15
- )
10
+ from hatchet_sdk.clients.admin import AdminClient
16
11
  from hatchet_sdk.clients.dispatcher.dispatcher import ( # type: ignore[attr-defined]
17
12
  Action,
18
13
  DispatcherClient,
19
14
  )
15
+ from hatchet_sdk.clients.durable_event_listener import (
16
+ DurableEventListener,
17
+ RegisterDurableEventRequest,
18
+ )
20
19
  from hatchet_sdk.clients.events import EventClient
21
- from hatchet_sdk.clients.rest.tenacity_utils import tenacity_retry
22
- from hatchet_sdk.clients.rest_client import RestApi
23
- from hatchet_sdk.clients.run_event_listener import RunEventListenerClient
24
- from hatchet_sdk.clients.workflow_listener import PooledWorkflowRunListener
25
20
  from hatchet_sdk.context.worker_context import WorkerContext
26
- from hatchet_sdk.contracts.dispatcher_pb2 import OverridesData
27
21
  from hatchet_sdk.logger import logger
22
+ from hatchet_sdk.utils.timedelta_to_expression import Duration, timedelta_to_expr
28
23
  from hatchet_sdk.utils.typing import JSONSerializableMapping, WorkflowValidator
29
- from hatchet_sdk.workflow_run import WorkflowRunRef
24
+ from hatchet_sdk.waits import SleepCondition, UserEventCondition
25
+
26
+ if TYPE_CHECKING:
27
+ from hatchet_sdk.runnables.task import Task
28
+ from hatchet_sdk.runnables.types import R, TWorkflowInput
29
+
30
30
 
31
31
  DEFAULT_WORKFLOW_POLLING_INTERVAL = 5 # Seconds
32
32
 
@@ -44,19 +44,14 @@ class StepRunError(BaseModel):
44
44
 
45
45
 
46
46
  class Context:
47
- spawn_index = -1
48
-
49
47
  def __init__(
50
48
  self,
51
49
  action: Action,
52
50
  dispatcher_client: DispatcherClient,
53
51
  admin_client: AdminClient,
54
52
  event_client: EventClient,
55
- rest_client: RestApi,
56
- workflow_listener: PooledWorkflowRunListener | None,
57
- workflow_run_event_listener: RunEventListenerClient,
53
+ durable_event_listener: DurableEventListener | None,
58
54
  worker: WorkerContext,
59
- namespace: str = "",
60
55
  validator_registry: dict[str, WorkflowValidator] = {},
61
56
  ):
62
57
  self.worker = worker
@@ -66,18 +61,12 @@ class Context:
66
61
 
67
62
  self.action = action
68
63
 
69
- # FIXME: stepRunId is a legacy field, we should remove it
70
- self.stepRunId = action.step_run_id
71
-
72
- self.step_run_id: str = action.step_run_id
64
+ self.step_run_id = action.step_run_id
73
65
  self.exit_flag = False
74
66
  self.dispatcher_client = dispatcher_client
75
67
  self.admin_client = admin_client
76
68
  self.event_client = event_client
77
- self.rest_client = rest_client
78
- self.workflow_listener = workflow_listener
79
- self.workflow_run_event_listener = workflow_run_event_listener
80
- self.namespace = namespace
69
+ self.durable_event_listener = durable_event_listener
81
70
 
82
71
  # FIXME: this limits the number of concurrent log requests to 1, which means we can do about
83
72
  # 100 log lines per second but this depends on network.
@@ -86,45 +75,46 @@ class Context:
86
75
 
87
76
  self.input = self.data.input
88
77
 
89
- def _prepare_workflow_options(
90
- self,
91
- key: str | None = None,
92
- options: ChildTriggerWorkflowOptions = ChildTriggerWorkflowOptions(),
93
- worker_id: str | None = None,
94
- ) -> TriggerWorkflowOptions:
95
- workflow_run_id = self.action.workflow_run_id
96
- step_run_id = self.action.step_run_id
97
-
98
- trigger_options = TriggerWorkflowOptions(
99
- parent_id=workflow_run_id,
100
- parent_step_run_id=step_run_id,
101
- child_key=key,
102
- child_index=self.spawn_index,
103
- additional_metadata=options.additional_metadata,
104
- desired_worker_id=worker_id if options.sticky else None,
105
- )
78
+ def was_skipped(self, task: "Task[TWorkflowInput, R]") -> bool:
79
+ return self.data.parents.get(task.name, {}).get("skipped", False)
80
+
81
+ @property
82
+ def trigger_data(self) -> JSONSerializableMapping:
83
+ return self.data.triggers
84
+
85
+ def task_output(self, task: "Task[TWorkflowInput, R]") -> "R":
86
+ from hatchet_sdk.runnables.types import R
106
87
 
107
- self.spawn_index += 1
108
- return trigger_options
88
+ if self.was_skipped(task):
89
+ raise ValueError("{task.name} was skipped")
90
+
91
+ action_prefix = self.action.action_id.split(":")[0]
109
92
 
110
- def step_output(self, step: str) -> dict[str, Any] | BaseModel:
111
93
  workflow_validator = next(
112
- (v for k, v in self.validator_registry.items() if k.split(":")[-1] == step),
94
+ (
95
+ v
96
+ for k, v in self.validator_registry.items()
97
+ if k == f"{action_prefix}:{task.name}"
98
+ ),
113
99
  None,
114
100
  )
115
101
 
116
102
  try:
117
- parent_step_data = cast(dict[str, Any], self.data.parents[step])
103
+ parent_step_data = cast(R, self.data.parents[task.name])
118
104
  except KeyError:
119
- raise ValueError(f"Step output for '{step}' not found")
105
+ raise ValueError(f"Step output for '{task.name}' not found")
120
106
 
121
- if workflow_validator and (v := workflow_validator.step_output):
122
- return v.model_validate(parent_step_data)
107
+ if (
108
+ parent_step_data
109
+ and workflow_validator
110
+ and (v := workflow_validator.step_output)
111
+ ):
112
+ return cast(R, v.model_validate(parent_step_data))
123
113
 
124
114
  return parent_step_data
125
115
 
126
116
  @property
127
- def triggered_by_event(self) -> bool:
117
+ def was_triggered_by_event(self) -> bool:
128
118
  return self.data.triggered_by == "event"
129
119
 
130
120
  @property
@@ -143,23 +133,9 @@ class Context:
143
133
  def done(self) -> bool:
144
134
  return self.exit_flag
145
135
 
146
- def playground(self, name: str, default: str | None = None) -> str | None:
147
- caller_file = get_caller_file_path()
148
-
149
- self.dispatcher_client.put_overrides_data(
150
- OverridesData(
151
- stepRunId=self.stepRunId,
152
- path=name,
153
- value=json.dumps(default),
154
- callerFilename=caller_file,
155
- )
156
- )
157
-
158
- return default
159
-
160
136
  def _log(self, line: str) -> tuple[bool, Exception | None]:
161
137
  try:
162
- self.event_client.log(message=line, step_run_id=self.stepRunId)
138
+ self.event_client.log(message=line, step_run_id=self.step_run_id)
163
139
  return True, None
164
140
  except Exception as e:
165
141
  # we don't want to raise an exception here, as it will kill the log thread
@@ -168,7 +144,7 @@ class Context:
168
144
  def log(
169
145
  self, line: str | JSONSerializableMapping, raise_on_error: bool = False
170
146
  ) -> None:
171
- if self.stepRunId == "":
147
+ if self.step_run_id == "":
172
148
  return
173
149
 
174
150
  if not isinstance(line, str):
@@ -198,24 +174,27 @@ class Context:
198
174
  future.add_done_callback(handle_result)
199
175
 
200
176
  def release_slot(self) -> None:
201
- return self.dispatcher_client.release_slot(self.stepRunId)
177
+ return self.dispatcher_client.release_slot(self.step_run_id)
202
178
 
203
179
  def _put_stream(self, data: str | bytes) -> None:
204
180
  try:
205
- self.event_client.stream(data=data, step_run_id=self.stepRunId)
181
+ self.event_client.stream(data=data, step_run_id=self.step_run_id)
206
182
  except Exception as e:
207
183
  logger.error(f"Error putting stream event: {e}")
208
184
 
209
185
  def put_stream(self, data: str | bytes) -> None:
210
- if self.stepRunId == "":
186
+ if self.step_run_id == "":
211
187
  return
212
188
 
213
189
  self.stream_event_thread_pool.submit(self._put_stream, data)
214
190
 
215
- def refresh_timeout(self, increment_by: str) -> None:
191
+ def refresh_timeout(self, increment_by: str | timedelta) -> None:
192
+ if isinstance(increment_by, timedelta):
193
+ increment_by = timedelta_to_expr(increment_by)
194
+
216
195
  try:
217
196
  return self.dispatcher_client.refresh_timeout(
218
- step_run_id=self.stepRunId, increment_by=increment_by
197
+ step_run_id=self.step_run_id, increment_by=increment_by
219
198
  )
220
199
  except Exception as e:
221
200
  logger.error(f"Error refreshing timeout: {e}")
@@ -241,7 +220,7 @@ class Context:
241
220
  return self.action.parent_workflow_run_id
242
221
 
243
222
  @property
244
- def step_run_errors(self) -> dict[str, str]:
223
+ def task_run_errors(self) -> dict[str, str]:
245
224
  errors = self.data.step_run_errors
246
225
 
247
226
  if not errors:
@@ -251,96 +230,44 @@ class Context:
251
230
 
252
231
  return errors
253
232
 
254
- def fetch_run_failures(self) -> list[StepRunError]:
255
- data = self.rest_client.workflow_run_get(self.action.workflow_run_id)
256
- other_job_runs = [
257
- run for run in (data.job_runs or []) if run.job_id != self.action.job_id
258
- ]
259
-
260
- return [
261
- StepRunError(
262
- step_id=step_run.step_id,
263
- step_run_action_name=step_run.step.action,
264
- error=step_run.error,
265
- )
266
- for job_run in other_job_runs
267
- if job_run.step_runs
268
- for step_run in job_run.step_runs
269
- if step_run.error and step_run.step
270
- ]
271
-
272
- @tenacity_retry
273
- async def aio_spawn_workflow(
233
+ def fetch_task_run_error(
274
234
  self,
275
- workflow_name: str,
276
- input: JSONSerializableMapping = {},
277
- key: str | None = None,
278
- options: ChildTriggerWorkflowOptions = ChildTriggerWorkflowOptions(),
279
- ) -> WorkflowRunRef:
280
- worker_id = self.worker.id()
235
+ task: "Task[TWorkflowInput, R]",
236
+ ) -> str | None:
237
+ errors = self.data.step_run_errors
281
238
 
282
- trigger_options = self._prepare_workflow_options(key, options, worker_id)
239
+ return errors.get(task.name)
283
240
 
284
- return await self.admin_client.aio_run_workflow(
285
- workflow_name, input, trigger_options
286
- )
287
241
 
288
- @tenacity_retry
289
- async def aio_spawn_workflows(
290
- self, child_workflow_runs: list[ChildWorkflowRunDict]
291
- ) -> list[WorkflowRunRef]:
242
+ class DurableContext(Context):
243
+ async def aio_wait_for(
244
+ self, signal_key: str, *conditions: SleepCondition | UserEventCondition
245
+ ) -> dict[str, Any]:
246
+ if self.durable_event_listener is None:
247
+ raise ValueError("Durable event listener is not available")
292
248
 
293
- if len(child_workflow_runs) == 0:
294
- raise Exception("no child workflows to spawn")
249
+ task_id = self.step_run_id
295
250
 
296
- worker_id = self.worker.id()
251
+ request = RegisterDurableEventRequest(
252
+ task_id=task_id,
253
+ signal_key=signal_key,
254
+ conditions=list(conditions),
255
+ )
297
256
 
298
- bulk_trigger_workflow_runs = [
299
- WorkflowRunDict(
300
- workflow_name=child_workflow_run.workflow_name,
301
- input=child_workflow_run.input,
302
- options=self._prepare_workflow_options(
303
- child_workflow_run.key, child_workflow_run.options, worker_id
304
- ),
305
- )
306
- for child_workflow_run in child_workflow_runs
307
- ]
257
+ self.durable_event_listener.register_durable_event(request)
308
258
 
309
- return await self.admin_client.aio_run_workflows(bulk_trigger_workflow_runs)
259
+ return await self.durable_event_listener.result(
260
+ task_id,
261
+ signal_key,
262
+ )
310
263
 
311
- @tenacity_retry
312
- def spawn_workflow(
313
- self,
314
- workflow_name: str,
315
- input: JSONSerializableMapping = {},
316
- key: str | None = None,
317
- options: ChildTriggerWorkflowOptions = ChildTriggerWorkflowOptions(),
318
- ) -> WorkflowRunRef:
319
- worker_id = self.worker.id()
320
-
321
- trigger_options = self._prepare_workflow_options(key, options, worker_id)
322
-
323
- return self.admin_client.run_workflow(workflow_name, input, trigger_options)
324
-
325
- @tenacity_retry
326
- def spawn_workflows(
327
- self, child_workflow_runs: list[ChildWorkflowRunDict]
328
- ) -> list[WorkflowRunRef]:
329
-
330
- if len(child_workflow_runs) == 0:
331
- raise Exception("no child workflows to spawn")
332
-
333
- worker_id = self.worker.id()
334
-
335
- bulk_trigger_workflow_runs = [
336
- WorkflowRunDict(
337
- workflow_name=child_workflow_run.workflow_name,
338
- input=child_workflow_run.input,
339
- options=self._prepare_workflow_options(
340
- child_workflow_run.key, child_workflow_run.options, worker_id
341
- ),
342
- )
343
- for child_workflow_run in child_workflow_runs
344
- ]
264
+ async def aio_sleep_for(self, duration: Duration) -> dict[str, Any]:
265
+ """
266
+ Lightweight wrapper for durable sleep. Allows for shorthand usage of `ctx.aio_wait_for` when specifying a sleep condition.
267
+
268
+ For more complicated conditions, use `ctx.aio_wait_for` directly.
269
+ """
345
270
 
346
- return self.admin_client.run_workflows(bulk_trigger_workflow_runs)
271
+ return await self.aio_wait_for(
272
+ f"sleep:{timedelta_to_expr(duration)}", SleepCondition(duration=duration)
273
+ )
File: hatchet_sdk/contracts/dispatcher_pb2_grpc.py
@@ -3,7 +3,7 @@
3
3
  import grpc
4
4
  import warnings
5
5
 
6
- from . import dispatcher_pb2 as dispatcher__pb2
6
+ from hatchet_sdk.contracts import dispatcher_pb2 as dispatcher__pb2
7
7
 
8
8
  GRPC_GENERATED_VERSION = '1.64.1'
9
9
  GRPC_VERSION = grpc.__version__
File: hatchet_sdk/contracts/events_pb2.py
@@ -15,14 +15,14 @@ _sym_db = _symbol_database.Default()
15
15
  from google.protobuf import timestamp_pb2 as google_dot_protobuf_dot_timestamp__pb2
16
16
 
17
17
 
18
- DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x0c\x65vents.proto\x1a\x1fgoogle/protobuf/timestamp.proto\"\xb4\x01\n\x05\x45vent\x12\x10\n\x08tenantId\x18\x01 \x01(\t\x12\x0f\n\x07\x65ventId\x18\x02 \x01(\t\x12\x0b\n\x03key\x18\x03 \x01(\t\x12\x0f\n\x07payload\x18\x04 \x01(\t\x12\x32\n\x0e\x65ventTimestamp\x18\x05 \x01(\x0b\x32\x1a.google.protobuf.Timestamp\x12\x1f\n\x12\x61\x64\x64itionalMetadata\x18\x06 \x01(\tH\x00\x88\x01\x01\x42\x15\n\x13_additionalMetadata\" \n\x06\x45vents\x12\x16\n\x06\x65vents\x18\x01 \x03(\x0b\x32\x06.Event\"\x92\x01\n\rPutLogRequest\x12\x11\n\tstepRunId\x18\x01 \x01(\t\x12-\n\tcreatedAt\x18\x02 \x01(\x0b\x32\x1a.google.protobuf.Timestamp\x12\x0f\n\x07message\x18\x03 \x01(\t\x12\x12\n\x05level\x18\x04 \x01(\tH\x00\x88\x01\x01\x12\x10\n\x08metadata\x18\x05 \x01(\tB\x08\n\x06_level\"\x10\n\x0ePutLogResponse\"|\n\x15PutStreamEventRequest\x12\x11\n\tstepRunId\x18\x01 \x01(\t\x12-\n\tcreatedAt\x18\x02 \x01(\x0b\x32\x1a.google.protobuf.Timestamp\x12\x0f\n\x07message\x18\x03 \x01(\x0c\x12\x10\n\x08metadata\x18\x05 \x01(\t\"\x18\n\x16PutStreamEventResponse\"9\n\x14\x42ulkPushEventRequest\x12!\n\x06\x65vents\x18\x01 \x03(\x0b\x32\x11.PushEventRequest\"\x9c\x01\n\x10PushEventRequest\x12\x0b\n\x03key\x18\x01 \x01(\t\x12\x0f\n\x07payload\x18\x02 \x01(\t\x12\x32\n\x0e\x65ventTimestamp\x18\x03 \x01(\x0b\x32\x1a.google.protobuf.Timestamp\x12\x1f\n\x12\x61\x64\x64itionalMetadata\x18\x04 \x01(\tH\x00\x88\x01\x01\x42\x15\n\x13_additionalMetadata\"%\n\x12ReplayEventRequest\x12\x0f\n\x07\x65ventId\x18\x01 
\x01(\t2\x88\x02\n\rEventsService\x12#\n\x04Push\x12\x11.PushEventRequest\x1a\x06.Event\"\x00\x12,\n\x08\x42ulkPush\x12\x15.BulkPushEventRequest\x1a\x07.Events\"\x00\x12\x32\n\x11ReplaySingleEvent\x12\x13.ReplayEventRequest\x1a\x06.Event\"\x00\x12+\n\x06PutLog\x12\x0e.PutLogRequest\x1a\x0f.PutLogResponse\"\x00\x12\x43\n\x0ePutStreamEvent\x12\x16.PutStreamEventRequest\x1a\x17.PutStreamEventResponse\"\x00\x42GZEgithub.com/hatchet-dev/hatchet/internal/services/dispatcher/contractsb\x06proto3')
18
+ DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x0c\x65vents.proto\x1a\x1fgoogle/protobuf/timestamp.proto\"\xb4\x01\n\x05\x45vent\x12\x10\n\x08tenantId\x18\x01 \x01(\t\x12\x0f\n\x07\x65ventId\x18\x02 \x01(\t\x12\x0b\n\x03key\x18\x03 \x01(\t\x12\x0f\n\x07payload\x18\x04 \x01(\t\x12\x32\n\x0e\x65ventTimestamp\x18\x05 \x01(\x0b\x32\x1a.google.protobuf.Timestamp\x12\x1f\n\x12\x61\x64\x64itionalMetadata\x18\x06 \x01(\tH\x00\x88\x01\x01\x42\x15\n\x13_additionalMetadata\" \n\x06\x45vents\x12\x16\n\x06\x65vents\x18\x01 \x03(\x0b\x32\x06.Event\"\x92\x01\n\rPutLogRequest\x12\x11\n\tstepRunId\x18\x01 \x01(\t\x12-\n\tcreatedAt\x18\x02 \x01(\x0b\x32\x1a.google.protobuf.Timestamp\x12\x0f\n\x07message\x18\x03 \x01(\t\x12\x12\n\x05level\x18\x04 \x01(\tH\x00\x88\x01\x01\x12\x10\n\x08metadata\x18\x05 \x01(\tB\x08\n\x06_level\"\x10\n\x0ePutLogResponse\"|\n\x15PutStreamEventRequest\x12\x11\n\tstepRunId\x18\x01 \x01(\t\x12-\n\tcreatedAt\x18\x02 \x01(\x0b\x32\x1a.google.protobuf.Timestamp\x12\x0f\n\x07message\x18\x03 \x01(\x0c\x12\x10\n\x08metadata\x18\x05 \x01(\t\"\x18\n\x16PutStreamEventResponse\"9\n\x14\x42ulkPushEventRequest\x12!\n\x06\x65vents\x18\x01 \x03(\x0b\x32\x11.PushEventRequest\"\x9c\x01\n\x10PushEventRequest\x12\x0b\n\x03key\x18\x01 \x01(\t\x12\x0f\n\x07payload\x18\x02 \x01(\t\x12\x32\n\x0e\x65ventTimestamp\x18\x03 \x01(\x0b\x32\x1a.google.protobuf.Timestamp\x12\x1f\n\x12\x61\x64\x64itionalMetadata\x18\x04 \x01(\tH\x00\x88\x01\x01\x42\x15\n\x13_additionalMetadata\"%\n\x12ReplayEventRequest\x12\x0f\n\x07\x65ventId\x18\x01 
\x01(\t2\x88\x02\n\rEventsService\x12#\n\x04Push\x12\x11.PushEventRequest\x1a\x06.Event\"\x00\x12,\n\x08\x42ulkPush\x12\x15.BulkPushEventRequest\x1a\x07.Events\"\x00\x12\x32\n\x11ReplaySingleEvent\x12\x13.ReplayEventRequest\x1a\x06.Event\"\x00\x12+\n\x06PutLog\x12\x0e.PutLogRequest\x1a\x0f.PutLogResponse\"\x00\x12\x43\n\x0ePutStreamEvent\x12\x16.PutStreamEventRequest\x1a\x17.PutStreamEventResponse\"\x00\x42\x45ZCgithub.com/hatchet-dev/hatchet/internal/services/ingestor/contractsb\x06proto3')
19
19
 
20
20
  _globals = globals()
21
21
  _builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals)
22
22
  _builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, 'events_pb2', _globals)
23
23
  if not _descriptor._USE_C_DESCRIPTORS:
24
24
  _globals['DESCRIPTOR']._loaded_options = None
25
- _globals['DESCRIPTOR']._serialized_options = b'ZEgithub.com/hatchet-dev/hatchet/internal/services/dispatcher/contracts'
25
+ _globals['DESCRIPTOR']._serialized_options = b'ZCgithub.com/hatchet-dev/hatchet/internal/services/ingestor/contracts'
26
26
  _globals['_EVENT']._serialized_start=50
27
27
  _globals['_EVENT']._serialized_end=230
28
28
  _globals['_EVENTS']._serialized_start=232
File: hatchet_sdk/contracts/events_pb2_grpc.py
@@ -3,7 +3,7 @@
3
3
  import grpc
4
4
  import warnings
5
5
 
6
- from . import events_pb2 as events__pb2
6
+ from hatchet_sdk.contracts import events_pb2 as events__pb2
7
7
 
8
8
  GRPC_GENERATED_VERSION = '1.64.1'
9
9
  GRPC_VERSION = grpc.__version__