hatchet-sdk 1.0.0__py3-none-any.whl → 1.0.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of hatchet-sdk might be problematic.

Files changed (73)
  1. hatchet_sdk/__init__.py +32 -16
  2. hatchet_sdk/client.py +25 -63
  3. hatchet_sdk/clients/admin.py +203 -142
  4. hatchet_sdk/clients/dispatcher/action_listener.py +42 -42
  5. hatchet_sdk/clients/dispatcher/dispatcher.py +18 -16
  6. hatchet_sdk/clients/durable_event_listener.py +327 -0
  7. hatchet_sdk/clients/rest/__init__.py +12 -1
  8. hatchet_sdk/clients/rest/api/log_api.py +258 -0
  9. hatchet_sdk/clients/rest/api/task_api.py +32 -6
  10. hatchet_sdk/clients/rest/api/workflow_runs_api.py +626 -0
  11. hatchet_sdk/clients/rest/models/__init__.py +12 -1
  12. hatchet_sdk/clients/rest/models/v1_log_line.py +94 -0
  13. hatchet_sdk/clients/rest/models/v1_log_line_level.py +39 -0
  14. hatchet_sdk/clients/rest/models/v1_log_line_list.py +110 -0
  15. hatchet_sdk/clients/rest/models/v1_task_summary.py +80 -64
  16. hatchet_sdk/clients/rest/models/v1_trigger_workflow_run_request.py +95 -0
  17. hatchet_sdk/clients/rest/models/v1_workflow_run_display_name.py +98 -0
  18. hatchet_sdk/clients/rest/models/v1_workflow_run_display_name_list.py +114 -0
  19. hatchet_sdk/clients/rest/models/workflow_run_shape_item_for_workflow_run_details.py +9 -4
  20. hatchet_sdk/clients/rest/models/workflow_runs_metrics.py +5 -1
  21. hatchet_sdk/clients/run_event_listener.py +0 -1
  22. hatchet_sdk/clients/v1/api_client.py +81 -0
  23. hatchet_sdk/context/context.py +86 -159
  24. hatchet_sdk/contracts/dispatcher_pb2_grpc.py +1 -1
  25. hatchet_sdk/contracts/events_pb2.py +2 -2
  26. hatchet_sdk/contracts/events_pb2_grpc.py +1 -1
  27. hatchet_sdk/contracts/v1/dispatcher_pb2.py +36 -0
  28. hatchet_sdk/contracts/v1/dispatcher_pb2.pyi +38 -0
  29. hatchet_sdk/contracts/v1/dispatcher_pb2_grpc.py +145 -0
  30. hatchet_sdk/contracts/v1/shared/condition_pb2.py +39 -0
  31. hatchet_sdk/contracts/v1/shared/condition_pb2.pyi +72 -0
  32. hatchet_sdk/contracts/v1/shared/condition_pb2_grpc.py +29 -0
  33. hatchet_sdk/contracts/v1/workflows_pb2.py +67 -0
  34. hatchet_sdk/contracts/v1/workflows_pb2.pyi +228 -0
  35. hatchet_sdk/contracts/v1/workflows_pb2_grpc.py +234 -0
  36. hatchet_sdk/contracts/workflows_pb2_grpc.py +1 -1
  37. hatchet_sdk/features/cron.py +91 -121
  38. hatchet_sdk/features/logs.py +16 -0
  39. hatchet_sdk/features/metrics.py +75 -0
  40. hatchet_sdk/features/rate_limits.py +45 -0
  41. hatchet_sdk/features/runs.py +221 -0
  42. hatchet_sdk/features/scheduled.py +114 -131
  43. hatchet_sdk/features/workers.py +41 -0
  44. hatchet_sdk/features/workflows.py +55 -0
  45. hatchet_sdk/hatchet.py +463 -165
  46. hatchet_sdk/opentelemetry/instrumentor.py +8 -13
  47. hatchet_sdk/rate_limit.py +33 -39
  48. hatchet_sdk/runnables/contextvars.py +12 -0
  49. hatchet_sdk/runnables/standalone.py +192 -0
  50. hatchet_sdk/runnables/task.py +144 -0
  51. hatchet_sdk/runnables/types.py +138 -0
  52. hatchet_sdk/runnables/workflow.py +771 -0
  53. hatchet_sdk/utils/aio_utils.py +0 -79
  54. hatchet_sdk/utils/proto_enums.py +0 -7
  55. hatchet_sdk/utils/timedelta_to_expression.py +23 -0
  56. hatchet_sdk/utils/typing.py +2 -2
  57. hatchet_sdk/v0/clients/rest_client.py +9 -0
  58. hatchet_sdk/v0/worker/action_listener_process.py +18 -2
  59. hatchet_sdk/waits.py +120 -0
  60. hatchet_sdk/worker/action_listener_process.py +64 -30
  61. hatchet_sdk/worker/runner/run_loop_manager.py +35 -26
  62. hatchet_sdk/worker/runner/runner.py +72 -55
  63. hatchet_sdk/worker/runner/utils/capture_logs.py +3 -11
  64. hatchet_sdk/worker/worker.py +155 -118
  65. hatchet_sdk/workflow_run.py +4 -5
  66. {hatchet_sdk-1.0.0.dist-info → hatchet_sdk-1.0.1.dist-info}/METADATA +1 -2
  67. {hatchet_sdk-1.0.0.dist-info → hatchet_sdk-1.0.1.dist-info}/RECORD +69 -43
  68. {hatchet_sdk-1.0.0.dist-info → hatchet_sdk-1.0.1.dist-info}/entry_points.txt +2 -0
  69. hatchet_sdk/clients/rest_client.py +0 -636
  70. hatchet_sdk/semver.py +0 -30
  71. hatchet_sdk/worker/runner/utils/error_with_traceback.py +0 -6
  72. hatchet_sdk/workflow.py +0 -527
  73. {hatchet_sdk-1.0.0.dist-info → hatchet_sdk-1.0.1.dist-info}/WHEEL +0 -0
hatchet_sdk/clients/admin.py

@@ -5,18 +5,27 @@ from typing import Union, cast
 
  import grpc
  from google.protobuf import timestamp_pb2
- from pydantic import BaseModel, Field
+ from pydantic import BaseModel, ConfigDict, Field, field_validator
 
  from hatchet_sdk.clients.rest.tenacity_utils import tenacity_retry
  from hatchet_sdk.clients.run_event_listener import RunEventListenerClient
  from hatchet_sdk.clients.workflow_listener import PooledWorkflowRunListener
  from hatchet_sdk.config import ClientConfig
  from hatchet_sdk.connection import new_conn
- from hatchet_sdk.contracts import workflows_pb2 as workflow_protos
+ from hatchet_sdk.contracts import workflows_pb2 as v0_workflow_protos
+ from hatchet_sdk.contracts.v1 import workflows_pb2 as workflow_protos
+ from hatchet_sdk.contracts.v1.workflows_pb2_grpc import AdminServiceStub
  from hatchet_sdk.contracts.workflows_pb2_grpc import WorkflowServiceStub
  from hatchet_sdk.metadata import get_metadata
  from hatchet_sdk.rate_limit import RateLimitDuration
- from hatchet_sdk.utils.proto_enums import convert_python_enum_to_proto, maybe_int_to_str
+ from hatchet_sdk.runnables.contextvars import (
+     ctx_step_run_id,
+     ctx_worker_id,
+     ctx_workflow_run_id,
+     spawn_index_lock,
+     workflow_spawn_indices,
+ )
+ from hatchet_sdk.utils.proto_enums import convert_python_enum_to_proto
  from hatchet_sdk.utils.typing import JSONSerializableMapping
  from hatchet_sdk.workflow_run import WorkflowRunRef
 
@@ -29,28 +38,19 @@ class ScheduleTriggerWorkflowOptions(BaseModel):
      namespace: str | None = None
 
 
- class ChildTriggerWorkflowOptions(BaseModel):
-     additional_metadata: JSONSerializableMapping = Field(default_factory=dict)
-     sticky: bool = False
-
-
- class ChildWorkflowRunDict(BaseModel):
-     workflow_name: str
-     input: JSONSerializableMapping
-     options: ChildTriggerWorkflowOptions
-     key: str | None = None
-
-
  class TriggerWorkflowOptions(ScheduleTriggerWorkflowOptions):
      additional_metadata: JSONSerializableMapping = Field(default_factory=dict)
      desired_worker_id: str | None = None
      namespace: str | None = None
+     sticky: bool = False
+     key: str | None = None
 
 
- class WorkflowRunDict(BaseModel):
+ class WorkflowRunTriggerConfig(BaseModel):
      workflow_name: str
      input: JSONSerializableMapping
      options: TriggerWorkflowOptions
+     key: str | None = None
 
 
  class DedupeViolationErr(Exception):
@@ -63,55 +63,69 @@ class AdminClient:
      def __init__(self, config: ClientConfig):
          conn = new_conn(config, False)
          self.config = config
-         self.client = WorkflowServiceStub(conn)  # type: ignore[no-untyped-call]
+         self.client = AdminServiceStub(conn)  # type: ignore[no-untyped-call]
+         self.v0_client = WorkflowServiceStub(conn)  # type: ignore[no-untyped-call]
          self.token = config.token
          self.listener_client = RunEventListenerClient(config=config)
          self.namespace = config.namespace
 
          self.pooled_workflow_listener: PooledWorkflowRunListener | None = None
 
+     class TriggerWorkflowRequest(BaseModel):
+         model_config = ConfigDict(extra="ignore")
+
+         parent_id: str | None = None
+         parent_step_run_id: str | None = None
+         child_index: int | None = None
+         child_key: str | None = None
+         additional_metadata: str | None = None
+         desired_worker_id: str | None = None
+         priority: int | None = None
+
+         @field_validator("additional_metadata", mode="before")
+         @classmethod
+         def validate_additional_metadata(
+             cls, v: JSONSerializableMapping | None
+         ) -> bytes | None:
+             if not v:
+                 return None
+
+             try:
+                 return json.dumps(v).encode("utf-8")
+             except json.JSONDecodeError as e:
+                 raise ValueError(f"Error encoding payload: {e}")
+
      def _prepare_workflow_request(
          self,
          workflow_name: str,
          input: JSONSerializableMapping,
          options: TriggerWorkflowOptions,
-     ) -> workflow_protos.TriggerWorkflowRequest:
+     ) -> v0_workflow_protos.TriggerWorkflowRequest:
          try:
              payload_data = json.dumps(input)
-             _options = options.model_dump()
-
-             _options.pop("namespace")
-
-             try:
-                 _options = {
-                     **_options,
-                     "additional_metadata": json.dumps(
-                         options.additional_metadata
-                     ).encode("utf-8"),
-                 }
-             except json.JSONDecodeError as e:
-                 raise ValueError(f"Error encoding payload: {e}")
-
-             return workflow_protos.TriggerWorkflowRequest(
-                 name=workflow_name, input=payload_data, **_options
-             )
          except json.JSONDecodeError as e:
              raise ValueError(f"Error encoding payload: {e}")
 
+         _options = self.TriggerWorkflowRequest.model_validate(
+             options.model_dump()
+         ).model_dump()
+
+         return v0_workflow_protos.TriggerWorkflowRequest(
+             name=workflow_name, input=payload_data, **_options
+         )
+
      def _prepare_put_workflow_request(
          self,
          name: str,
-         workflow: workflow_protos.CreateWorkflowVersionOpts,
-         overrides: workflow_protos.CreateWorkflowVersionOpts | None = None,
-     ) -> workflow_protos.PutWorkflowRequest:
+         workflow: workflow_protos.CreateWorkflowVersionRequest,
+         overrides: workflow_protos.CreateWorkflowVersionRequest | None = None,
+     ) -> workflow_protos.CreateWorkflowVersionRequest:
          if overrides is not None:
              workflow.MergeFrom(overrides)
 
          workflow.name = name
 
-         return workflow_protos.PutWorkflowRequest(
-             opts=workflow,
-         )
+         return workflow
 
      def _parse_schedule(
          self, schedule: datetime | timestamp_pb2.Timestamp
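
Note: the nested TriggerWorkflowRequest pydantic model replaces the hand-rolled dict munging in _prepare_workflow_request. extra="ignore" drops option fields the proto constructor does not accept (such as namespace), and the before-validator JSON-encodes additional_metadata. A standalone sketch of the same pattern, assuming ordinary pydantic v2 semantics (the ProtoKwargs name and fields are illustrative, not from the SDK):

import json

from pydantic import BaseModel, ConfigDict, field_validator

# ProtoKwargs is an illustrative stand-in for AdminClient.TriggerWorkflowRequest.
class ProtoKwargs(BaseModel):
    model_config = ConfigDict(extra="ignore")  # unknown keys like "namespace" are dropped

    child_key: str | None = None
    additional_metadata: str | None = None

    @field_validator("additional_metadata", mode="before")
    @classmethod
    def encode_metadata(cls, v: dict | None) -> bytes | None:
        # Runs before normal validation, so a mapping becomes JSON bytes
        # that pydantic then coerces to str for the proto kwargs.
        return json.dumps(v).encode("utf-8") if v else None

kwargs = ProtoKwargs.model_validate(
    {"namespace": "ns_", "child_key": "k", "additional_metadata": {"a": 1}}
).model_dump()
# kwargs now contains only fields the proto accepts: "namespace" was ignored
# and additional_metadata arrives JSON-encoded.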
@@ -134,57 +148,21 @@ class AdminClient:
          schedules: list[Union[datetime, timestamp_pb2.Timestamp]],
          input: JSONSerializableMapping = {},
          options: ScheduleTriggerWorkflowOptions = ScheduleTriggerWorkflowOptions(),
-     ) -> workflow_protos.ScheduleWorkflowRequest:
-         return workflow_protos.ScheduleWorkflowRequest(
+     ) -> v0_workflow_protos.ScheduleWorkflowRequest:
+         return v0_workflow_protos.ScheduleWorkflowRequest(
              name=name,
              schedules=[self._parse_schedule(schedule) for schedule in schedules],
              input=json.dumps(input),
              **options.model_dump(),
          )
 
-     ## IMPORTANT: Keep this method's signature in sync with the wrapper in the OTel instrumentor
-     @tenacity_retry
-     async def aio_run_workflow(
-         self,
-         workflow_name: str,
-         input: JSONSerializableMapping,
-         options: TriggerWorkflowOptions = TriggerWorkflowOptions(),
-     ) -> WorkflowRunRef:
-         ## IMPORTANT: The `pooled_workflow_listener` must be created 1) lazily, and not at `init` time, and 2) on the
-         ## main thread. If 1) is not followed, you'll get an error about something being attached to the wrong event
-         ## loop. If 2) is not followed, you'll get an error about the event loop not being set up.
-         if not self.pooled_workflow_listener:
-             self.pooled_workflow_listener = PooledWorkflowRunListener(self.config)
-
-         return await asyncio.to_thread(self.run_workflow, workflow_name, input, options)
-
-     @tenacity_retry
-     async def aio_run_workflows(
-         self,
-         workflows: list[WorkflowRunDict],
-         options: TriggerWorkflowOptions = TriggerWorkflowOptions(),
-     ) -> list[WorkflowRunRef]:
-         ## IMPORTANT: The `pooled_workflow_listener` must be created 1) lazily, and not at `init` time, and 2) on the
-         ## main thread. If 1) is not followed, you'll get an error about something being attached to the wrong event
-         ## loop. If 2) is not followed, you'll get an error about the event loop not being set up.
-         if not self.pooled_workflow_listener:
-             self.pooled_workflow_listener = PooledWorkflowRunListener(self.config)
-
-         return await asyncio.to_thread(self.run_workflows, workflows, options)
-
      @tenacity_retry
      async def aio_put_workflow(
          self,
          name: str,
-         workflow: workflow_protos.CreateWorkflowVersionOpts,
-         overrides: workflow_protos.CreateWorkflowVersionOpts | None = None,
-     ) -> workflow_protos.WorkflowVersion:
-         ## IMPORTANT: The `pooled_workflow_listener` must be created 1) lazily, and not at `init` time, and 2) on the
-         ## main thread. If 1) is not followed, you'll get an error about something being attached to the wrong event
-         ## loop. If 2) is not followed, you'll get an error about the event loop not being set up.
-         if not self.pooled_workflow_listener:
-             self.pooled_workflow_listener = PooledWorkflowRunListener(self.config)
-
+         workflow: workflow_protos.CreateWorkflowVersionRequest,
+         overrides: workflow_protos.CreateWorkflowVersionRequest | None = None,
+     ) -> workflow_protos.CreateWorkflowVersionResponse:
          return await asyncio.to_thread(self.put_workflow, name, workflow, overrides)
 
      @tenacity_retry
@@ -194,12 +172,6 @@ class AdminClient:
          limit: int,
          duration: RateLimitDuration = RateLimitDuration.SECOND,
      ) -> None:
-         ## IMPORTANT: The `pooled_workflow_listener` must be created 1) lazily, and not at `init` time, and 2) on the
-         ## main thread. If 1) is not followed, you'll get an error about something being attached to the wrong event
-         ## loop. If 2) is not followed, you'll get an error about the event loop not being set up.
-         if not self.pooled_workflow_listener:
-             self.pooled_workflow_listener = PooledWorkflowRunListener(self.config)
-
          return await asyncio.to_thread(self.put_rate_limit, key, limit, duration)
 
      @tenacity_retry
@@ -209,13 +181,7 @@ class AdminClient:
          schedules: list[Union[datetime, timestamp_pb2.Timestamp]],
          input: JSONSerializableMapping = {},
          options: ScheduleTriggerWorkflowOptions = ScheduleTriggerWorkflowOptions(),
-     ) -> workflow_protos.WorkflowVersion:
-         ## IMPORTANT: The `pooled_workflow_listener` must be created 1) lazily, and not at `init` time, and 2) on the
-         ## main thread. If 1) is not followed, you'll get an error about something being attached to the wrong event
-         ## loop. If 2) is not followed, you'll get an error about the event loop not being set up.
-         if not self.pooled_workflow_listener:
-             self.pooled_workflow_listener = PooledWorkflowRunListener(self.config)
-
+     ) -> v0_workflow_protos.WorkflowVersion:
          return await asyncio.to_thread(
              self.schedule_workflow, name, schedules, input, options
          )
@@ -224,18 +190,19 @@ class AdminClient:
      def put_workflow(
          self,
          name: str,
-         workflow: workflow_protos.CreateWorkflowVersionOpts,
-         overrides: workflow_protos.CreateWorkflowVersionOpts | None = None,
-     ) -> workflow_protos.WorkflowVersion:
+         workflow: workflow_protos.CreateWorkflowVersionRequest,
+         overrides: workflow_protos.CreateWorkflowVersionRequest | None = None,
+     ) -> workflow_protos.CreateWorkflowVersionResponse:
          opts = self._prepare_put_workflow_request(name, workflow, overrides)
 
-         resp: workflow_protos.WorkflowVersion = self.client.PutWorkflow(
-             opts,
-             metadata=get_metadata(self.token),
+         return cast(
+             workflow_protos.CreateWorkflowVersionResponse,
+             self.client.PutWorkflow(
+                 opts,
+                 metadata=get_metadata(self.token),
+             ),
          )
 
-         return resp
-
      @tenacity_retry
      def put_rate_limit(
          self,
@@ -246,11 +213,12 @@ class AdminClient:
          duration_proto = convert_python_enum_to_proto(
              duration, workflow_protos.RateLimitDuration
          )
-         self.client.PutRateLimit(
-             workflow_protos.PutRateLimitRequest(
+
+         self.v0_client.PutRateLimit(
+             v0_workflow_protos.PutRateLimitRequest(
                  key=key,
                  limit=limit,
-                 duration=maybe_int_to_str(duration_proto),
+                 duration=duration_proto,  # type: ignore[arg-type]
              ),
              metadata=get_metadata(self.token),
          )
@@ -262,7 +230,7 @@ class AdminClient:
          schedules: list[Union[datetime, timestamp_pb2.Timestamp]],
          input: JSONSerializableMapping = {},
          options: ScheduleTriggerWorkflowOptions = ScheduleTriggerWorkflowOptions(),
-     ) -> workflow_protos.WorkflowVersion:
+     ) -> v0_workflow_protos.WorkflowVersion:
          try:
              namespace = options.namespace or self.namespace
 
@@ -274,8 +242,8 @@ class AdminClient:
              )
 
              return cast(
-                 workflow_protos.WorkflowVersion,
-                 self.client.ScheduleWorkflow(
+                 v0_workflow_protos.WorkflowVersion,
+                 self.v0_client.ScheduleWorkflow(
                      request,
                      metadata=get_metadata(self.token),
                  ),
@@ -286,6 +254,44 @@ class AdminClient:
 
              raise e
 
+     def _create_workflow_run_request(
+         self,
+         workflow_name: str,
+         input: JSONSerializableMapping,
+         options: TriggerWorkflowOptions,
+     ) -> v0_workflow_protos.TriggerWorkflowRequest:
+         workflow_run_id = ctx_workflow_run_id.get()
+         step_run_id = ctx_step_run_id.get()
+         worker_id = ctx_worker_id.get()
+         spawn_index = workflow_spawn_indices[workflow_run_id] if workflow_run_id else 0
+
+         ## Increment the spawn_index for the parent workflow
+         if workflow_run_id:
+             workflow_spawn_indices[workflow_run_id] += 1
+
+         desired_worker_id = (
+             (options.desired_worker_id or worker_id) if options.sticky else None
+         )
+         child_index = (
+             options.child_index if options.child_index is not None else spawn_index
+         )
+
+         trigger_options = TriggerWorkflowOptions(
+             parent_id=options.parent_id or workflow_run_id,
+             parent_step_run_id=options.parent_step_run_id or step_run_id,
+             child_key=options.child_key,
+             child_index=child_index,
+             additional_metadata=options.additional_metadata,
+             desired_worker_id=desired_worker_id,
+         )
+
+         namespace = options.namespace or self.namespace
+
+         if namespace != "" and not workflow_name.startswith(self.namespace):
+             workflow_name = f"{namespace}{workflow_name}"
+
+         return self._prepare_workflow_request(workflow_name, input, trigger_options)
+
      ## IMPORTANT: Keep this method's signature in sync with the wrapper in the OTel instrumentor
      @tenacity_retry
      def run_workflow(
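
Note: _create_workflow_run_request pulls the parent run, step run, and worker IDs from context variables and hands each spawned child a child_index from a per-parent spawn counter; aio_run_workflow serializes the read-and-increment behind spawn_index_lock. The new hatchet_sdk/runnables/contextvars.py (+12 lines, file 48 above) is not shown in this diff; a plausible reconstruction from the names imported here, offered only as a guess:

import asyncio
from collections import defaultdict
from contextvars import ContextVar

# Hypothetical reconstruction: only the five names are confirmed by the
# imports in admin.py; defaults and types are guesses.
ctx_workflow_run_id: ContextVar[str | None] = ContextVar("ctx_workflow_run_id", default=None)
ctx_step_run_id: ContextVar[str | None] = ContextVar("ctx_step_run_id", default=None)
ctx_worker_id: ContextVar[str | None] = ContextVar("ctx_worker_id", default=None)

# Per-parent counter used to hand each spawned child a stable child_index.
workflow_spawn_indices: defaultdict[str, int] = defaultdict(int)

# Guards the read-and-increment of the spawn counter in async code paths.
spawn_index_lock = asyncio.Lock()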
@@ -294,70 +300,125 @@ class AdminClient:
          input: JSONSerializableMapping,
          options: TriggerWorkflowOptions = TriggerWorkflowOptions(),
      ) -> WorkflowRunRef:
+         request = self._create_workflow_run_request(workflow_name, input, options)
+
+         if not self.pooled_workflow_listener:
+             self.pooled_workflow_listener = PooledWorkflowRunListener(self.config)
+
          try:
-             if not self.pooled_workflow_listener:
-                 self.pooled_workflow_listener = PooledWorkflowRunListener(self.config)
+             resp = cast(
+                 v0_workflow_protos.TriggerWorkflowResponse,
+                 self.v0_client.TriggerWorkflow(
+                     request,
+                     metadata=get_metadata(self.token),
+                 ),
+             )
+         except (grpc.RpcError, grpc.aio.AioRpcError) as e:
+             if e.code() == grpc.StatusCode.ALREADY_EXISTS:
+                 raise DedupeViolationErr(e.details())
 
-             namespace = options.namespace or self.namespace
+         return WorkflowRunRef(
+             workflow_run_id=resp.workflow_run_id,
+             workflow_listener=self.pooled_workflow_listener,
+             workflow_run_event_listener=self.listener_client,
+         )
 
-             if namespace != "" and not workflow_name.startswith(self.namespace):
-                 workflow_name = f"{namespace}{workflow_name}"
+     ## IMPORTANT: Keep this method's signature in sync with the wrapper in the OTel instrumentor
+     @tenacity_retry
+     async def aio_run_workflow(
+         self,
+         workflow_name: str,
+         input: JSONSerializableMapping,
+         options: TriggerWorkflowOptions = TriggerWorkflowOptions(),
+     ) -> WorkflowRunRef:
+         ## IMPORTANT: The `pooled_workflow_listener` must be created 1) lazily, and not at `init` time, and 2) on the
+         ## main thread. If 1) is not followed, you'll get an error about something being attached to the wrong event
+         ## loop. If 2) is not followed, you'll get an error about the event loop not being set up.
+         async with spawn_index_lock:
+             request = self._create_workflow_run_request(workflow_name, input, options)
 
-             request = self._prepare_workflow_request(workflow_name, input, options)
+         if not self.pooled_workflow_listener:
+             self.pooled_workflow_listener = PooledWorkflowRunListener(self.config)
 
+         try:
              resp = cast(
-                 workflow_protos.TriggerWorkflowResponse,
-                 self.client.TriggerWorkflow(
+                 v0_workflow_protos.TriggerWorkflowResponse,
+                 self.v0_client.TriggerWorkflow(
                      request,
                      metadata=get_metadata(self.token),
                  ),
              )
-
-             return WorkflowRunRef(
-                 workflow_run_id=resp.workflow_run_id,
-                 workflow_listener=self.pooled_workflow_listener,
-                 workflow_run_event_listener=self.listener_client,
-             )
          except (grpc.RpcError, grpc.aio.AioRpcError) as e:
              if e.code() == grpc.StatusCode.ALREADY_EXISTS:
                  raise DedupeViolationErr(e.details())
 
              raise e
 
-     def _prepare_workflow_run_request(
-         self, workflow: WorkflowRunDict, options: TriggerWorkflowOptions
-     ) -> workflow_protos.TriggerWorkflowRequest:
-         workflow_name = workflow.workflow_name
-         input_data = workflow.input
-         options = workflow.options
-
-         namespace = options.namespace or self.namespace
-
-         if namespace != "" and not workflow_name.startswith(self.namespace):
-             workflow_name = f"{namespace}{workflow_name}"
-
-         return self._prepare_workflow_request(workflow_name, input_data, options)
+         return WorkflowRunRef(
+             workflow_run_id=resp.workflow_run_id,
+             workflow_listener=self.pooled_workflow_listener,
+             workflow_run_event_listener=self.listener_client,
+         )
 
      ## IMPORTANT: Keep this method's signature in sync with the wrapper in the OTel instrumentor
      @tenacity_retry
      def run_workflows(
          self,
-         workflows: list[WorkflowRunDict],
-         options: TriggerWorkflowOptions = TriggerWorkflowOptions(),
+         workflows: list[WorkflowRunTriggerConfig],
      ) -> list[WorkflowRunRef]:
          if not self.pooled_workflow_listener:
              self.pooled_workflow_listener = PooledWorkflowRunListener(self.config)
 
-         bulk_request = workflow_protos.BulkTriggerWorkflowRequest(
+         bulk_request = v0_workflow_protos.BulkTriggerWorkflowRequest(
              workflows=[
-                 self._prepare_workflow_run_request(workflow, options)
+                 self._create_workflow_run_request(
+                     workflow.workflow_name, workflow.input, workflow.options
+                 )
                  for workflow in workflows
              ]
          )
 
          resp = cast(
-             workflow_protos.BulkTriggerWorkflowResponse,
-             self.client.BulkTriggerWorkflow(
+             v0_workflow_protos.BulkTriggerWorkflowResponse,
+             self.v0_client.BulkTriggerWorkflow(
+                 bulk_request,
+                 metadata=get_metadata(self.token),
+             ),
+         )
+
+         return [
+             WorkflowRunRef(
+                 workflow_run_id=workflow_run_id,
+                 workflow_listener=self.pooled_workflow_listener,
+                 workflow_run_event_listener=self.listener_client,
+             )
+             for workflow_run_id in resp.workflow_run_ids
+         ]
+
+     @tenacity_retry
+     async def aio_run_workflows(
+         self,
+         workflows: list[WorkflowRunTriggerConfig],
+     ) -> list[WorkflowRunRef]:
+         ## IMPORTANT: The `pooled_workflow_listener` must be created 1) lazily, and not at `init` time, and 2) on the
+         ## main thread. If 1) is not followed, you'll get an error about something being attached to the wrong event
+         ## loop. If 2) is not followed, you'll get an error about the event loop not being set up.
+         if not self.pooled_workflow_listener:
+             self.pooled_workflow_listener = PooledWorkflowRunListener(self.config)
+
+         async with spawn_index_lock:
+             bulk_request = v0_workflow_protos.BulkTriggerWorkflowRequest(
+                 workflows=[
+                     self._create_workflow_run_request(
+                         workflow.workflow_name, workflow.input, workflow.options
+                     )
+                     for workflow in workflows
+                 ]
+             )
+
+         resp = cast(
+             v0_workflow_protos.BulkTriggerWorkflowResponse,
+             self.v0_client.BulkTriggerWorkflow(
                  bulk_request,
                  metadata=get_metadata(self.token),
              ),
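
Note: run_workflows and aio_run_workflows no longer accept a shared options argument; each entry now carries its own TriggerWorkflowOptions inside a WorkflowRunTriggerConfig. A hedged usage sketch, where admin stands for an already-configured AdminClient and the workflow details are invented:

from hatchet_sdk.clients.admin import TriggerWorkflowOptions, WorkflowRunTriggerConfig

# admin is assumed to be an AdminClient built from a valid ClientConfig.
refs = admin.run_workflows(
    [
        WorkflowRunTriggerConfig(
            workflow_name="resize-image",  # invented workflow
            input={"image_id": i},
            options=TriggerWorkflowOptions(
                additional_metadata={"batch": "2024-06"},  # invented
            ),
        )
        for i in range(10)
    ]
)

# One WorkflowRunRef per triggered run, in the order returned by the server.
for ref in refs:
    print(ref.workflow_run_id)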
hatchet_sdk/clients/dispatcher/action_listener.py

@@ -1,14 +1,14 @@
  import asyncio
  import json
  import time
- from dataclasses import dataclass, field
+ from dataclasses import field
  from enum import Enum
- from typing import Any, AsyncGenerator, Optional, cast
+ from typing import Any, AsyncGenerator, cast
 
  import grpc
  import grpc.aio
  from grpc._cython import cygrpc  # type: ignore[attr-defined]
- from pydantic import BaseModel, ConfigDict, Field, field_validator
+ from pydantic import BaseModel, ConfigDict, Field, field_validator, model_validator
 
  from hatchet_sdk.clients.event_ts import ThreadSafeEvent, read_with_interrupt
  from hatchet_sdk.clients.events import proto_timestamp_now
@@ -36,25 +36,29 @@ DEFAULT_ACTION_TIMEOUT = 600  # seconds
  DEFAULT_ACTION_LISTENER_RETRY_COUNT = 15
 
 
- @dataclass
- class GetActionListenerRequest:
+ class GetActionListenerRequest(BaseModel):
+     model_config = ConfigDict(arbitrary_types_allowed=True)
+
      worker_name: str
      services: list[str]
      actions: list[str]
-     max_runs: int | None = None
-     _labels: dict[str, str | int] = field(default_factory=dict)
+     slots: int = 100
+     raw_labels: dict[str, str | int] = Field(default_factory=dict)
 
-     labels: dict[str, WorkerLabels] = field(init=False)
+     labels: dict[str, WorkerLabels] = Field(default_factory=dict)
 
-     def __post_init__(self) -> None:
+     @model_validator(mode="after")
+     def validate_labels(self) -> "GetActionListenerRequest":
          self.labels = {}
 
-         for key, value in self._labels.items():
+         for key, value in self.raw_labels.items():
              if isinstance(value, int):
                  self.labels[key] = WorkerLabels(intValue=value)
              else:
                  self.labels[key] = WorkerLabels(strValue=str(value))
 
+         return self
+
 
  class ActionPayload(BaseModel):
      model_config = ConfigDict(extra="allow")
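
Note: GetActionListenerRequest changes from a dataclass with __post_init__ to a pydantic model whose after-validator derives labels from raw_labels, and max_runs is replaced by slots (default 100). A construction sketch, assuming the class stays importable from this module (the worker name, action, and labels are invented):

from hatchet_sdk.clients.dispatcher.action_listener import GetActionListenerRequest

req = GetActionListenerRequest(
    worker_name="worker-1",  # invented
    services=["default"],
    actions=["my-workflow:my-task"],  # invented action id
    slots=50,  # replaces the old max_runs field
    raw_labels={"region": "us-east-1", "gpu": 1},
)

# The after-validator has already run: int values become intValue labels,
# everything else becomes strValue labels.
assert set(req.labels) == {"region", "gpu"}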
@@ -65,6 +69,7 @@ class ActionPayload(BaseModel):
      user_data: JSONSerializableMapping = Field(default_factory=dict)
      step_run_errors: dict[str, str] = Field(default_factory=dict)
      triggered_by: str | None = None
+     triggers: JSONSerializableMapping = Field(default_factory=dict)
 
      @field_validator(
          "input", "parents", "overrides", "user_data", "step_run_errors", mode="before"
@@ -142,30 +147,24 @@ def parse_additional_metadata(additional_metadata: str) -> JSONSerializableMapping:
          return {}
 
 
- @dataclass
  class ActionListener:
-     config: ClientConfig
-     worker_id: str
+     def __init__(self, config: ClientConfig, worker_id: str) -> None:
+         self.config = config
+         self.worker_id = worker_id
 
-     client: DispatcherStub = field(init=False)
-     aio_client: DispatcherStub = field(init=False)
-     token: str = field(init=False)
-     retries: int = field(default=0, init=False)
-     last_connection_attempt: float = field(default=0, init=False)
-     last_heartbeat_succeeded: bool = field(default=True, init=False)
-     time_last_hb_succeeded: float = field(default=9999999999999, init=False)
-     heartbeat_task: Optional[asyncio.Task[None]] = field(default=None, init=False)
-     run_heartbeat: bool = field(default=True, init=False)
-     listen_strategy: str = field(default="v2", init=False)
-     stop_signal: bool = field(default=False, init=False)
-
-     missed_heartbeats: int = field(default=0, init=False)
-
-     def __post_init__(self) -> None:
-         self.client = DispatcherStub(new_conn(self.config, False))  # type: ignore[no-untyped-call]
          self.aio_client = DispatcherStub(new_conn(self.config, True))  # type: ignore[no-untyped-call]
          self.token = self.config.token
 
+         self.retries = 0
+         self.last_heartbeat_succeeded = True
+         self.time_last_hb_succeeded = 9999999999999.0
+         self.last_connection_attempt = 0.0
+         self.heartbeat_task: asyncio.Task[None] | None = None
+         self.run_heartbeat = True
+         self.listen_strategy = "v2"
+         self.stop_signal = False
+         self.missed_heartbeats = 0
+
      def is_healthy(self) -> bool:
          return self.last_heartbeat_succeeded
 
@@ -292,11 +291,16 @@ class ActionListener:
 
                  self.retries = 0
 
-                 action_payload = (
-                     {}
-                     if not assigned_action.actionPayload
-                     else self.parse_action_payload(assigned_action.actionPayload)
-                 )
+                 try:
+                     action_payload = (
+                         ActionPayload()
+                         if not assigned_action.actionPayload
+                         else ActionPayload.model_validate_json(
+                             assigned_action.actionPayload
+                         )
+                     )
+                 except (ValueError, json.JSONDecodeError) as e:
+                     raise ValueError(f"Error decoding payload: {e}")
 
                  action = Action(
                      tenant_id=assigned_action.tenantId,
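
Note: the listener now parses the action payload directly into ActionPayload with model_validate_json, instead of the removed parse_action_payload helper (json.loads followed by model_validate). A minimal sketch of the difference, with an invented payload:

from hatchet_sdk.clients.dispatcher.action_listener import ActionPayload

raw = '{"input": {"order_id": 1}, "triggered_by": "event"}'  # invented payload

# 1.0.0 (removed): payload = ActionPayload.model_validate(json.loads(raw))
# 1.0.1: pydantic parses the JSON and validates the fields in one step.
payload = ActionPayload.model_validate_json(raw)
assert payload.triggered_by == "event"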
@@ -309,7 +313,7 @@ class ActionListener:
                      step_id=assigned_action.stepId,
                      step_run_id=assigned_action.stepRunId,
                      action_id=assigned_action.actionId,
-                     action_payload=ActionPayload.model_validate(action_payload),
+                     action_payload=action_payload,
                      action_type=convert_proto_enum_to_python(
                          assigned_action.actionType,
                          ActionType,
@@ -352,16 +356,10 @@ class ActionListener:
 
          self.retries = self.retries + 1
 
-     def parse_action_payload(self, payload: str) -> JSONSerializableMapping:
-         try:
-             return cast(JSONSerializableMapping, json.loads(payload))
-         except json.JSONDecodeError as e:
-             raise ValueError(f"Error decoding payload: {e}")
-
      async def get_listen_client(
          self,
      ) -> grpc.aio.UnaryStreamCall[WorkerListenRequest, AssignedAction]:
-         current_time = int(time.time())
+         current_time = time.time()
 
          if (
              current_time - self.last_connection_attempt
@@ -438,8 +436,10 @@ class ActionListener:
                  timeout=5,
                  metadata=get_metadata(self.token),
              )
+
              if self.interrupt is not None:
                  self.interrupt.set()
+
              return cast(WorkerUnsubscribeRequest, req)
          except grpc.RpcError as e:
              raise Exception(f"Failed to unsubscribe: {e}")