hatchet-sdk 1.2.5__py3-none-any.whl → 1.3.0__py3-none-any.whl

This diff shows the content of publicly available package versions released to a supported registry. It is provided for informational purposes only and reflects the changes between the two versions as they appear in their respective public registries.

Potentially problematic release: this version of hatchet-sdk might be problematic.

Files changed (60)
  1. hatchet_sdk/__init__.py +7 -5
  2. hatchet_sdk/client.py +14 -6
  3. hatchet_sdk/clients/admin.py +57 -15
  4. hatchet_sdk/clients/dispatcher/action_listener.py +2 -2
  5. hatchet_sdk/clients/dispatcher/dispatcher.py +20 -7
  6. hatchet_sdk/clients/event_ts.py +25 -5
  7. hatchet_sdk/clients/listeners/durable_event_listener.py +125 -0
  8. hatchet_sdk/clients/listeners/pooled_listener.py +255 -0
  9. hatchet_sdk/clients/listeners/workflow_listener.py +62 -0
  10. hatchet_sdk/clients/rest/api/api_token_api.py +24 -24
  11. hatchet_sdk/clients/rest/api/default_api.py +64 -64
  12. hatchet_sdk/clients/rest/api/event_api.py +64 -64
  13. hatchet_sdk/clients/rest/api/github_api.py +8 -8
  14. hatchet_sdk/clients/rest/api/healthcheck_api.py +16 -16
  15. hatchet_sdk/clients/rest/api/log_api.py +16 -16
  16. hatchet_sdk/clients/rest/api/metadata_api.py +24 -24
  17. hatchet_sdk/clients/rest/api/rate_limits_api.py +8 -8
  18. hatchet_sdk/clients/rest/api/slack_api.py +16 -16
  19. hatchet_sdk/clients/rest/api/sns_api.py +24 -24
  20. hatchet_sdk/clients/rest/api/step_run_api.py +56 -56
  21. hatchet_sdk/clients/rest/api/task_api.py +56 -56
  22. hatchet_sdk/clients/rest/api/tenant_api.py +128 -128
  23. hatchet_sdk/clients/rest/api/user_api.py +96 -96
  24. hatchet_sdk/clients/rest/api/worker_api.py +24 -24
  25. hatchet_sdk/clients/rest/api/workflow_api.py +144 -144
  26. hatchet_sdk/clients/rest/api/workflow_run_api.py +48 -48
  27. hatchet_sdk/clients/rest/api/workflow_runs_api.py +40 -40
  28. hatchet_sdk/clients/rest/api_client.py +5 -8
  29. hatchet_sdk/clients/rest/configuration.py +7 -3
  30. hatchet_sdk/clients/rest/models/tenant_step_run_queue_metrics.py +2 -2
  31. hatchet_sdk/clients/rest/models/v1_task_summary.py +5 -0
  32. hatchet_sdk/clients/rest/models/workflow_runs_metrics.py +5 -1
  33. hatchet_sdk/clients/rest/rest.py +160 -111
  34. hatchet_sdk/clients/v1/api_client.py +2 -2
  35. hatchet_sdk/context/context.py +22 -21
  36. hatchet_sdk/features/cron.py +41 -40
  37. hatchet_sdk/features/logs.py +7 -6
  38. hatchet_sdk/features/metrics.py +19 -18
  39. hatchet_sdk/features/runs.py +88 -68
  40. hatchet_sdk/features/scheduled.py +42 -42
  41. hatchet_sdk/features/workers.py +17 -16
  42. hatchet_sdk/features/workflows.py +15 -14
  43. hatchet_sdk/hatchet.py +1 -1
  44. hatchet_sdk/runnables/standalone.py +12 -9
  45. hatchet_sdk/runnables/task.py +66 -2
  46. hatchet_sdk/runnables/types.py +8 -0
  47. hatchet_sdk/runnables/workflow.py +48 -136
  48. hatchet_sdk/waits.py +8 -8
  49. hatchet_sdk/worker/runner/run_loop_manager.py +4 -4
  50. hatchet_sdk/worker/runner/runner.py +22 -11
  51. hatchet_sdk/worker/worker.py +29 -25
  52. hatchet_sdk/workflow_run.py +55 -9
  53. {hatchet_sdk-1.2.5.dist-info → hatchet_sdk-1.3.0.dist-info}/METADATA +1 -1
  54. {hatchet_sdk-1.2.5.dist-info → hatchet_sdk-1.3.0.dist-info}/RECORD +57 -57
  55. hatchet_sdk/clients/durable_event_listener.py +0 -329
  56. hatchet_sdk/clients/workflow_listener.py +0 -288
  57. hatchet_sdk/utils/aio.py +0 -43
  58. /hatchet_sdk/clients/{run_event_listener.py → listeners/run_event_listener.py} +0 -0
  59. {hatchet_sdk-1.2.5.dist-info → hatchet_sdk-1.3.0.dist-info}/WHEEL +0 -0
  60. {hatchet_sdk-1.2.5.dist-info → hatchet_sdk-1.3.0.dist-info}/entry_points.txt +0 -0
hatchet_sdk/runnables/standalone.py CHANGED
@@ -11,7 +11,6 @@ from hatchet_sdk.contracts.workflows_pb2 import WorkflowVersion
 from hatchet_sdk.runnables.task import Task
 from hatchet_sdk.runnables.types import EmptyModel, R, TWorkflowInput
 from hatchet_sdk.runnables.workflow import BaseWorkflow, Workflow
-from hatchet_sdk.utils.aio import run_async_from_sync
 from hatchet_sdk.utils.typing import JSONSerializableMapping, is_basemodel_subclass
 from hatchet_sdk.workflow_run import WorkflowRunRef

@@ -31,11 +30,15 @@ class TaskRunRef(Generic[TWorkflowInput, R]):
         return self.workflow_run_id

     async def aio_result(self) -> R:
-        result = await self._wrr.workflow_listener.aio_result(self._wrr.workflow_run_id)
+        result = await self._wrr.workflow_run_listener.aio_result(
+            self._wrr.workflow_run_id
+        )
         return self._s._extract_result(result)

     def result(self) -> R:
-        return run_async_from_sync(self.aio_result)
+        result = self._wrr.result()
+
+        return self._s._extract_result(result)


 class Standalone(BaseWorkflow[TWorkflowInput], Generic[TWorkflowInput, R]):
@@ -129,7 +132,7 @@ class Standalone(BaseWorkflow[TWorkflowInput], Generic[TWorkflowInput, R]):
     def schedule(
         self,
         run_at: datetime,
-        input: TWorkflowInput | None = None,
+        input: TWorkflowInput = cast(TWorkflowInput, EmptyModel()),
         options: ScheduleTriggerWorkflowOptions = ScheduleTriggerWorkflowOptions(),
     ) -> WorkflowVersion:
         return self._workflow.schedule(
@@ -141,7 +144,7 @@ class Standalone(BaseWorkflow[TWorkflowInput], Generic[TWorkflowInput, R]):
     async def aio_schedule(
         self,
         run_at: datetime,
-        input: TWorkflowInput,
+        input: TWorkflowInput = cast(TWorkflowInput, EmptyModel()),
         options: ScheduleTriggerWorkflowOptions = ScheduleTriggerWorkflowOptions(),
     ) -> WorkflowVersion:
         return await self._workflow.aio_schedule(
@@ -154,8 +157,8 @@ class Standalone(BaseWorkflow[TWorkflowInput], Generic[TWorkflowInput, R]):
         self,
         cron_name: str,
         expression: str,
-        input: TWorkflowInput,
-        additional_metadata: JSONSerializableMapping,
+        input: TWorkflowInput = cast(TWorkflowInput, EmptyModel()),
+        additional_metadata: JSONSerializableMapping = {},
     ) -> CronWorkflows:
         return self._workflow.create_cron(
             cron_name=cron_name,
@@ -168,8 +171,8 @@ class Standalone(BaseWorkflow[TWorkflowInput], Generic[TWorkflowInput, R]):
         self,
         cron_name: str,
         expression: str,
-        input: TWorkflowInput,
-        additional_metadata: JSONSerializableMapping,
+        input: TWorkflowInput = cast(TWorkflowInput, EmptyModel()),
+        additional_metadata: JSONSerializableMapping = {},
     ) -> CronWorkflows:
         return await self._workflow.aio_create_cron(
             cron_name=cron_name,
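Note: with these changes, `input` defaults to an empty `EmptyModel` and `additional_metadata` to `{}` across `schedule`, `aio_schedule`, `create_cron`, and `aio_create_cron`, and the synchronous `TaskRunRef.result()` now delegates to `WorkflowRunRef.result()` instead of running the async path itself via the removed `run_async_from_sync`. A minimal usage sketch (the `my_task` standalone is hypothetical, standing in for a task declared elsewhere with this SDK):

    from datetime import datetime, timedelta

    # Scheduling no longer requires an explicit input:
    my_task.schedule(run_at=datetime.now() + timedelta(minutes=5))

    # Cron creation no longer requires input or additional_metadata:
    my_task.create_cron(cron_name="nightly", expression="0 0 * * *")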
hatchet_sdk/runnables/task.py CHANGED
@@ -10,7 +10,9 @@ from typing import (
 )

 from hatchet_sdk.context.context import Context, DurableContext
+from hatchet_sdk.contracts.v1.shared.condition_pb2 import TaskConditions
 from hatchet_sdk.contracts.v1.workflows_pb2 import (
+    CreateTaskOpts,
     CreateTaskRateLimit,
     DesiredWorkerLabels,
 )
@@ -25,8 +27,15 @@ from hatchet_sdk.runnables.types import (
     is_durable_sync_fn,
     is_sync_fn,
 )
-from hatchet_sdk.utils.timedelta_to_expression import Duration
-from hatchet_sdk.waits import Condition, OrGroup
+from hatchet_sdk.utils.timedelta_to_expression import Duration, timedelta_to_expr
+from hatchet_sdk.waits import (
+    Action,
+    Condition,
+    OrGroup,
+    ParentCondition,
+    SleepCondition,
+    UserEventCondition,
+)

 if TYPE_CHECKING:
     from hatchet_sdk.runnables.workflow import Workflow
@@ -142,3 +151,58 @@ class Task(Generic[TWorkflowInput, R]):
            return await self.fn(workflow_input, cast(Context, ctx))  # type: ignore

        raise TypeError(f"{self.name} is not an async function. Use `call` instead.")
+
+    def to_proto(self, service_name: str) -> CreateTaskOpts:
+        return CreateTaskOpts(
+            readable_id=self.name,
+            action=service_name + ":" + self.name,
+            timeout=timedelta_to_expr(self.execution_timeout),
+            inputs="{}",
+            parents=[p.name for p in self.parents],
+            retries=self.retries,
+            rate_limits=self.rate_limits,
+            worker_labels=self.desired_worker_labels,
+            backoff_factor=self.backoff_factor,
+            backoff_max_seconds=self.backoff_max_seconds,
+            concurrency=[t.to_proto() for t in self.concurrency],
+            conditions=self._conditions_to_proto(),
+            schedule_timeout=timedelta_to_expr(self.schedule_timeout),
+        )
+
+    def _assign_action(self, condition: Condition, action: Action) -> Condition:
+        condition.base.action = action
+
+        return condition
+
+    def _conditions_to_proto(self) -> TaskConditions:
+        wait_for_conditions = [
+            self._assign_action(w, Action.QUEUE) for w in self.wait_for
+        ]
+
+        cancel_if_conditions = [
+            self._assign_action(c, Action.CANCEL) for c in self.cancel_if
+        ]
+        skip_if_conditions = [self._assign_action(s, Action.SKIP) for s in self.skip_if]
+
+        conditions = wait_for_conditions + cancel_if_conditions + skip_if_conditions
+
+        if len({c.base.readable_data_key for c in conditions}) != len(
+            [c.base.readable_data_key for c in conditions]
+        ):
+            raise ValueError("Conditions must have unique readable data keys.")
+
+        user_events = [
+            c.to_proto() for c in conditions if isinstance(c, UserEventCondition)
+        ]
+        parent_overrides = [
+            c.to_proto() for c in conditions if isinstance(c, ParentCondition)
+        ]
+        sleep_conditions = [
+            c.to_proto() for c in conditions if isinstance(c, SleepCondition)
+        ]
+
+        return TaskConditions(
+            parent_override_conditions=parent_overrides,
+            sleep_conditions=sleep_conditions,
+            user_event_conditions=user_events,
+        )
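`Task.to_proto` and `Task._conditions_to_proto` move here from `BaseWorkflow` (see `hatchet_sdk/runnables/workflow.py` below). The duplicate check compares the number of distinct `readable_data_key`s against the total count; a standalone illustration of the idiom (the keys are made up):

    keys = ["sleep:10s", "event:user.signup", "sleep:10s"]

    # A set drops duplicates, so a length mismatch means at least one key repeats.
    if len(set(keys)) != len(keys):
        raise ValueError("Conditions must have unique readable data keys.")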
hatchet_sdk/runnables/types.py CHANGED
@@ -6,6 +6,7 @@ from typing import Any, Awaitable, Callable, ParamSpec, Type, TypeGuard, TypeVar
 from pydantic import BaseModel, ConfigDict, Field, StrictInt, model_validator

 from hatchet_sdk.context.context import Context, DurableContext
+from hatchet_sdk.contracts.v1.workflows_pb2 import Concurrency
 from hatchet_sdk.utils.timedelta_to_expression import Duration
 from hatchet_sdk.utils.typing import JSONSerializableMapping

@@ -52,6 +53,13 @@ class ConcurrencyExpression(BaseModel):
     max_runs: int
     limit_strategy: ConcurrencyLimitStrategy

+    def to_proto(self) -> Concurrency:
+        return Concurrency(
+            expression=self.expression,
+            max_runs=self.max_runs,
+            limit_strategy=self.limit_strategy,
+        )
+

 TWorkflowInput = TypeVar("TWorkflowInput", bound=BaseModel)

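`ConcurrencyExpression` can now serialize itself, which is what lets the workflow-building code below call `c.to_proto()` directly instead of routing through the removed `BaseWorkflow._concurrency_to_proto` helper. A sketch, assuming the `ConcurrencyLimitStrategy` enum defined in this module:

    expression = ConcurrencyExpression(
        expression="input.user_id",
        max_runs=5,
        limit_strategy=ConcurrencyLimitStrategy.GROUP_ROUND_ROBIN,
    )

    proto = expression.to_proto()  # a Concurrency proto carrying the same three fields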
hatchet_sdk/runnables/workflow.py CHANGED
@@ -1,6 +1,6 @@
 import asyncio
 from datetime import datetime
-from typing import TYPE_CHECKING, Any, Callable, Generic, Union, cast, overload
+from typing import TYPE_CHECKING, Any, Callable, Generic, cast

 from google.protobuf import timestamp_pb2
 from pydantic import BaseModel
@@ -12,10 +12,7 @@ from hatchet_sdk.clients.admin import (
 )
 from hatchet_sdk.clients.rest.models.cron_workflows import CronWorkflows
 from hatchet_sdk.context.context import Context, DurableContext
-from hatchet_sdk.contracts.v1.shared.condition_pb2 import TaskConditions
 from hatchet_sdk.contracts.v1.workflows_pb2 import (
-    Concurrency,
-    CreateTaskOpts,
     CreateWorkflowVersionRequest,
     DesiredWorkerLabels,
 )
@@ -36,16 +33,9 @@ from hatchet_sdk.runnables.types import (
     WorkflowConfig,
 )
 from hatchet_sdk.utils.proto_enums import convert_python_enum_to_proto
-from hatchet_sdk.utils.timedelta_to_expression import Duration, timedelta_to_expr
+from hatchet_sdk.utils.timedelta_to_expression import Duration
 from hatchet_sdk.utils.typing import JSONSerializableMapping
-from hatchet_sdk.waits import (
-    Action,
-    Condition,
-    OrGroup,
-    ParentCondition,
-    SleepCondition,
-    UserEventCondition,
-)
+from hatchet_sdk.waits import Condition, OrGroup
 from hatchet_sdk.workflow_run import WorkflowRunRef

 if TYPE_CHECKING:
@@ -78,16 +68,12 @@ class BaseWorkflow(Generic[TWorkflowInput]):
         self._on_success_task: Task[TWorkflowInput, Any] | None = None
         self.client = client

-    def _get_service_name(self, namespace: str) -> str:
-        return f"{namespace}{self.config.name.lower()}"
-
-    def _create_action_name(
-        self, namespace: str, step: Task[TWorkflowInput, Any]
-    ) -> str:
-        return self._get_service_name(namespace) + ":" + step.name
+    @property
+    def service_name(self) -> str:
+        return f"{self.client.config.namespace}{self.config.name.lower()}"

-    def _get_name(self, namespace: str) -> str:
-        return namespace + self.config.name
+    def _create_action_name(self, step: Task[TWorkflowInput, Any]) -> str:
+        return self.service_name + ":" + step.name

     def _raise_for_invalid_concurrency(
         self, concurrency: ConcurrencyExpression
@@ -106,58 +92,6 @@ class BaseWorkflow(Generic[TWorkflowInput]):

         return True

-    @overload
-    def _concurrency_to_proto(self, concurrency: None) -> None: ...
-
-    @overload
-    def _concurrency_to_proto(
-        self, concurrency: ConcurrencyExpression
-    ) -> Concurrency: ...
-
-    def _concurrency_to_proto(
-        self, concurrency: ConcurrencyExpression | None
-    ) -> Concurrency | None:
-        if not concurrency:
-            return None
-
-        self._raise_for_invalid_concurrency(concurrency)
-
-        return Concurrency(
-            expression=concurrency.expression,
-            max_runs=concurrency.max_runs,
-            limit_strategy=concurrency.limit_strategy,
-        )
-
-    @overload
-    def _validate_task(
-        self, task: "Task[TWorkflowInput, R]", service_name: str
-    ) -> CreateTaskOpts: ...
-
-    @overload
-    def _validate_task(self, task: None, service_name: str) -> None: ...
-
-    def _validate_task(
-        self, task: Union["Task[TWorkflowInput, R]", None], service_name: str
-    ) -> CreateTaskOpts | None:
-        if not task:
-            return None
-
-        return CreateTaskOpts(
-            readable_id=task.name,
-            action=service_name + ":" + task.name,
-            timeout=timedelta_to_expr(task.execution_timeout),
-            inputs="{}",
-            parents=[p.name for p in task.parents],
-            retries=task.retries,
-            rate_limits=task.rate_limits,
-            worker_labels=task.desired_worker_labels,
-            backoff_factor=task.backoff_factor,
-            backoff_max_seconds=task.backoff_max_seconds,
-            concurrency=[self._concurrency_to_proto(t) for t in task.concurrency],
-            conditions=self._conditions_to_proto(task),
-            schedule_timeout=timedelta_to_expr(task.schedule_timeout),
-        )
-
     def _validate_priority(self, default_priority: int | None) -> int | None:
         validated_priority = (
             max(1, min(3, default_priority)) if default_priority else None
@@ -169,51 +103,14 @@ class BaseWorkflow(Generic[TWorkflowInput]):

         return validated_priority

-    def _assign_action(self, condition: Condition, action: Action) -> Condition:
-        condition.base.action = action
-
-        return condition
-
-    def _conditions_to_proto(self, task: Task[TWorkflowInput, Any]) -> TaskConditions:
-        wait_for_conditions = [
-            self._assign_action(w, Action.QUEUE) for w in task.wait_for
-        ]
-
-        cancel_if_conditions = [
-            self._assign_action(c, Action.CANCEL) for c in task.cancel_if
-        ]
-        skip_if_conditions = [self._assign_action(s, Action.SKIP) for s in task.skip_if]
-
-        conditions = wait_for_conditions + cancel_if_conditions + skip_if_conditions
-
-        if len({c.base.readable_data_key for c in conditions}) != len(
-            [c.base.readable_data_key for c in conditions]
-        ):
-            raise ValueError("Conditions must have unique readable data keys.")
-
-        user_events = [
-            c.to_pb() for c in conditions if isinstance(c, UserEventCondition)
-        ]
-        parent_overrides = [
-            c.to_pb() for c in conditions if isinstance(c, ParentCondition)
-        ]
-        sleep_conditions = [
-            c.to_pb() for c in conditions if isinstance(c, SleepCondition)
-        ]
-
-        return TaskConditions(
-            parent_override_conditions=parent_overrides,
-            sleep_conditions=sleep_conditions,
-            user_event_conditions=user_events,
-        )
-
     def _is_leaf_task(self, task: Task[TWorkflowInput, Any]) -> bool:
         return not any(task in t.parents for t in self.tasks if task != t)

-    def _get_create_opts(self, namespace: str) -> CreateWorkflowVersionRequest:
-        service_name = self._get_service_name(namespace)
+    def to_proto(self) -> CreateWorkflowVersionRequest:
+        namespace = self.client.config.namespace
+        service_name = self.service_name

-        name = self._get_name(namespace)
+        name = self.name
         event_triggers = [namespace + event for event in self.config.on_events]

         if self._on_success_task:
@@ -223,10 +120,12 @@ class BaseWorkflow(Generic[TWorkflowInput]):
                 if task.type == StepType.DEFAULT and self._is_leaf_task(task)
             ]

-        on_success_task = self._validate_task(self._on_success_task, service_name)
+        on_success_task = (
+            t.to_proto(service_name) if (t := self._on_success_task) else None
+        )

         tasks = [
-            self._validate_task(task, service_name)
+            task.to_proto(service_name)
             for task in self.tasks
             if task.type == StepType.DEFAULT
         ]
@@ -234,7 +133,9 @@ class BaseWorkflow(Generic[TWorkflowInput]):
         if on_success_task:
             tasks += [on_success_task]

-        on_failure_task = self._validate_task(self._on_failure_task, service_name)
+        on_failure_task = (
+            t.to_proto(service_name) if (t := self._on_failure_task) else None
+        )

         return CreateWorkflowVersionRequest(
             name=name,
@@ -243,7 +144,7 @@ class BaseWorkflow(Generic[TWorkflowInput]):
             event_triggers=event_triggers,
             cron_triggers=self.config.on_crons,
             tasks=tasks,
-            concurrency=self._concurrency_to_proto(self.config.concurrency),
+            concurrency=(c.to_proto() if (c := self.config.concurrency) else None),
             ## TODO: Fix this
             cron_input=None,
             on_failure_task=on_failure_task,
@@ -274,21 +175,32 @@ class BaseWorkflow(Generic[TWorkflowInput]):

     @property
     def name(self) -> str:
-        return self._get_name(self.client.config.namespace)
+        return self.client.config.namespace + self.config.name

     def create_bulk_run_item(
         self,
-        input: TWorkflowInput | None = None,
+        input: TWorkflowInput = cast(TWorkflowInput, EmptyModel()),
         key: str | None = None,
         options: TriggerWorkflowOptions = TriggerWorkflowOptions(),
     ) -> WorkflowRunTriggerConfig:
         return WorkflowRunTriggerConfig(
             workflow_name=self.config.name,
-            input=input.model_dump() if input else {},
+            input=self._serialize_input(input),
             options=options,
             key=key,
         )

+    def _serialize_input(self, input: TWorkflowInput | None) -> JSONSerializableMapping:
+        if not input:
+            return {}
+
+        if isinstance(input, BaseModel):
+            return input.model_dump(mode="json")
+
+        raise ValueError(
+            f"Input must be a BaseModel or `None`, got {type(input)} instead."
+        )
+

 class Workflow(BaseWorkflow[TWorkflowInput]):
     """
@@ -303,7 +215,7 @@ class Workflow(BaseWorkflow[TWorkflowInput]):
     ) -> WorkflowRunRef:
         return self.client._client.admin.run_workflow(
             workflow_name=self.config.name,
-            input=input.model_dump() if input else {},
+            input=self._serialize_input(input),
             options=options,
         )

@@ -314,7 +226,7 @@ class Workflow(BaseWorkflow[TWorkflowInput]):
     ) -> dict[str, Any]:
         ref = self.client._client.admin.run_workflow(
             workflow_name=self.config.name,
-            input=input.model_dump() if input else {},
+            input=self._serialize_input(input),
             options=options,
         )

@@ -327,7 +239,7 @@ class Workflow(BaseWorkflow[TWorkflowInput]):
     ) -> WorkflowRunRef:
         return await self.client._client.admin.aio_run_workflow(
             workflow_name=self.config.name,
-            input=input.model_dump() if input else {},
+            input=self._serialize_input(input),
             options=options,
         )

@@ -338,7 +250,7 @@ class Workflow(BaseWorkflow[TWorkflowInput]):
     ) -> dict[str, Any]:
         ref = await self.client._client.admin.aio_run_workflow(
             workflow_name=self.config.name,
-            input=input.model_dump() if input else {},
+            input=self._serialize_input(input),
             options=options,
         )

@@ -383,26 +295,26 @@ class Workflow(BaseWorkflow[TWorkflowInput]):
     def schedule(
         self,
         run_at: datetime,
-        input: TWorkflowInput | None = None,
+        input: TWorkflowInput = cast(TWorkflowInput, EmptyModel()),
         options: ScheduleTriggerWorkflowOptions = ScheduleTriggerWorkflowOptions(),
     ) -> WorkflowVersion:
         return self.client._client.admin.schedule_workflow(
             name=self.config.name,
             schedules=cast(list[datetime | timestamp_pb2.Timestamp], [run_at]),
-            input=input.model_dump() if input else {},
+            input=self._serialize_input(input),
             options=options,
         )

     async def aio_schedule(
         self,
         run_at: datetime,
-        input: TWorkflowInput,
+        input: TWorkflowInput = cast(TWorkflowInput, EmptyModel()),
         options: ScheduleTriggerWorkflowOptions = ScheduleTriggerWorkflowOptions(),
     ) -> WorkflowVersion:
         return await self.client._client.admin.aio_schedule_workflow(
             name=self.config.name,
             schedules=cast(list[datetime | timestamp_pb2.Timestamp], [run_at]),
-            input=input.model_dump(),
+            input=self._serialize_input(input),
             options=options,
         )

@@ -410,14 +322,14 @@ class Workflow(BaseWorkflow[TWorkflowInput]):
         self,
         cron_name: str,
         expression: str,
-        input: TWorkflowInput,
-        additional_metadata: JSONSerializableMapping,
+        input: TWorkflowInput = cast(TWorkflowInput, EmptyModel()),
+        additional_metadata: JSONSerializableMapping = {},
     ) -> CronWorkflows:
         return self.client.cron.create(
             workflow_name=self.config.name,
             cron_name=cron_name,
             expression=expression,
-            input=input.model_dump(),
+            input=self._serialize_input(input),
             additional_metadata=additional_metadata,
         )

@@ -425,14 +337,14 @@ class Workflow(BaseWorkflow[TWorkflowInput]):
         self,
         cron_name: str,
         expression: str,
-        input: TWorkflowInput,
-        additional_metadata: JSONSerializableMapping,
+        input: TWorkflowInput = cast(TWorkflowInput, EmptyModel()),
+        additional_metadata: JSONSerializableMapping = {},
     ) -> CronWorkflows:
         return await self.client.cron.aio_create(
             workflow_name=self.config.name,
             cron_name=cron_name,
             expression=expression,
-            input=input.model_dump(),
+            input=self._serialize_input(input),
             additional_metadata=additional_metadata,
         )

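The new `_serialize_input` helper centralizes what the call sites above previously inlined (`input.model_dump() if input else {}`) and dumps with `mode="json"`, so non-JSON types are coerced to JSON-safe values. A behavior sketch with a hypothetical input model:

    from datetime import datetime
    from pydantic import BaseModel

    class OrderInput(BaseModel):  # hypothetical input model
        order_id: int
        placed_at: datetime

    # mode="json" coerces datetime to an ISO 8601 string:
    OrderInput(order_id=1, placed_at=datetime(2025, 1, 1)).model_dump(mode="json")
    # -> {'order_id': 1, 'placed_at': '2025-01-01T00:00:00'}

Falsy input serializes to `{}`, and anything that is not a `BaseModel` raises a `ValueError`.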
hatchet_sdk/waits.py CHANGED
@@ -37,7 +37,7 @@ class BaseCondition(BaseModel):
     or_group_id: str = Field(default_factory=generate_or_group_id)
     expression: str | None = None

-    def to_pb(self) -> BaseMatchCondition:
+    def to_proto(self) -> BaseMatchCondition:
         return BaseMatchCondition(
             readable_data_key=self.readable_data_key,
             action=convert_python_enum_to_proto(self.action, ProtoAction),  # type: ignore[arg-type]
@@ -51,7 +51,7 @@ class Condition(ABC):
         self.base = base

     @abstractmethod
-    def to_pb(
+    def to_proto(
         self,
     ) -> UserEventMatchCondition | ParentOverrideMatchCondition | SleepMatchCondition:
         pass
@@ -67,9 +67,9 @@ class SleepCondition(Condition):

         self.duration = duration

-    def to_pb(self) -> SleepMatchCondition:
+    def to_proto(self) -> SleepMatchCondition:
         return SleepMatchCondition(
-            base=self.base.to_pb(),
+            base=self.base.to_proto(),
             sleep_for=timedelta_to_expr(self.duration),
         )

@@ -86,9 +86,9 @@ class UserEventCondition(Condition):
         self.event_key = event_key
         self.expression = expression

-    def to_pb(self) -> UserEventMatchCondition:
+    def to_proto(self) -> UserEventMatchCondition:
         return UserEventMatchCondition(
-            base=self.base.to_pb(),
+            base=self.base.to_proto(),
             user_event_key=self.event_key,
         )

@@ -103,9 +103,9 @@ class ParentCondition(Condition):

         self.parent = parent

-    def to_pb(self) -> ParentOverrideMatchCondition:
+    def to_proto(self) -> ParentOverrideMatchCondition:
         return ParentOverrideMatchCondition(
-            base=self.base.to_pb(),
+            base=self.base.to_proto(),
             parent_readable_id=self.parent.name,
         )

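The `to_pb` → `to_proto` rename is mechanical and lines up with the new `to_proto` methods on `Task` and `ConcurrencyExpression`; any external code that called `to_pb` directly needs updating. A sketch (the ten-second sleep condition is made up):

    from datetime import timedelta

    cond = SleepCondition(duration=timedelta(seconds=10))
    proto = cond.to_proto()  # previously cond.to_pb()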
hatchet_sdk/worker/runner/run_loop_manager.py CHANGED
@@ -55,17 +55,17 @@ class WorkerActionRunLoopManager:
         self.client = Client(config=self.config, debug=self.debug)
         self.start()

-    def start(self, retry_count: int = 1) -> None:
-        k = self.loop.create_task(self.aio_start(retry_count))  # noqa: F841
+    def start(self) -> None:
+        k = self.loop.create_task(self.aio_start())  # noqa: F841

     async def aio_start(self, retry_count: int = 1) -> None:
         await capture_logs(
             self.client.log_interceptor,
             self.client.event,
             self._async_start,
-        )(retry_count=retry_count)
+        )()

-    async def _async_start(self, retry_count: int = 1) -> None:
+    async def _async_start(self) -> None:
         logger.info("starting runner...")
         self.loop = asyncio.get_running_loop()
         # needed for graceful termination
hatchet_sdk/worker/runner/runner.py CHANGED
@@ -16,9 +16,10 @@ from hatchet_sdk.client import Client
 from hatchet_sdk.clients.admin import AdminClient
 from hatchet_sdk.clients.dispatcher.action_listener import Action, ActionType
 from hatchet_sdk.clients.dispatcher.dispatcher import DispatcherClient
-from hatchet_sdk.clients.durable_event_listener import DurableEventListener
-from hatchet_sdk.clients.run_event_listener import RunEventListenerClient
-from hatchet_sdk.clients.workflow_listener import PooledWorkflowRunListener
+from hatchet_sdk.clients.events import EventClient
+from hatchet_sdk.clients.listeners.durable_event_listener import DurableEventListener
+from hatchet_sdk.clients.listeners.run_event_listener import RunEventListenerClient
+from hatchet_sdk.clients.listeners.workflow_listener import PooledWorkflowRunListener
 from hatchet_sdk.config import ClientConfig
 from hatchet_sdk.context.context import Context, DurableContext
 from hatchet_sdk.context.worker_context import WorkerContext
@@ -31,6 +32,7 @@ from hatchet_sdk.contracts.dispatcher_pb2 import (
     STEP_EVENT_TYPE_STARTED,
 )
 from hatchet_sdk.exceptions import NonRetryableException
+from hatchet_sdk.features.runs import RunsClient
 from hatchet_sdk.logger import logger
 from hatchet_sdk.runnables.contextvars import (
     ctx_step_run_id,
@@ -66,7 +68,7 @@ class Runner:
     ):
         # We store the config so we can dynamically create clients for the dispatcher client.
         self.config = config
-        self.client = Client(config)
+
         self.slots = slots
         self.tasks: dict[str, asyncio.Task[Any]] = {}  # Store run ids and futures
         self.contexts: dict[str, Context] = {}  # Store run ids and contexts
@@ -82,12 +84,21 @@ class Runner:
         self.killing = False
         self.handle_kill = handle_kill

-        # We need to initialize a new admin and dispatcher client *after* we've started the event loop,
-        # otherwise the grpc.aio methods will use a different event loop and we'll get a bunch of errors.
         self.dispatcher_client = DispatcherClient(self.config)
-        self.admin_client = AdminClient(self.config)
         self.workflow_run_event_listener = RunEventListenerClient(self.config)
-        self.client.workflow_listener = PooledWorkflowRunListener(self.config)
+        self.workflow_listener = PooledWorkflowRunListener(self.config)
+        self.runs_client = RunsClient(
+            config=self.config,
+            workflow_run_event_listener=self.workflow_run_event_listener,
+            workflow_run_listener=self.workflow_listener,
+        )
+        self.admin_client = AdminClient(
+            self.config,
+            self.workflow_listener,
+            self.workflow_run_event_listener,
+            self.runs_client,
+        )
+        self.event_client = EventClient(self.config)
         self.durable_event_listener = DurableEventListener(self.config)

         self.worker_context = WorkerContext(
@@ -291,11 +302,11 @@ class Runner:
             action=action,
             dispatcher_client=self.dispatcher_client,
             admin_client=self.admin_client,
-            event_client=self.client.event,
+            event_client=self.event_client,
             durable_event_listener=self.durable_event_listener,
             worker=self.worker_context,
             validator_registry=self.validator_registry,
-            runs_client=self.client.runs,
+            runs_client=self.runs_client,
         )

     ## IMPORTANT: Keep this method's signature in sync with the wrapper in the OTel instrumentor
@@ -430,7 +441,7 @@ class Runner:
         # check if thread is still running, if so, print a warning
         if run_id in self.threads:
             thread = self.threads.get(run_id)
-            if thread and self.client.config.enable_force_kill_sync_threads:
+            if thread and self.config.enable_force_kill_sync_threads:
                 self.force_kill_thread(thread)
                 await asyncio.sleep(1)
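These import changes track the module moves in this release (files 55–58 in the list above): the listener clients now live under `hatchet_sdk/clients/listeners/`. Code that imported them from the old locations needs the new paths:

    # 1.2.5
    # from hatchet_sdk.clients.workflow_listener import PooledWorkflowRunListener
    # from hatchet_sdk.clients.run_event_listener import RunEventListenerClient

    # 1.3.0
    from hatchet_sdk.clients.listeners.workflow_listener import PooledWorkflowRunListener
    from hatchet_sdk.clients.listeners.run_event_listener import RunEventListenerClient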