hatchet-sdk 1.0.0__py3-none-any.whl → 1.0.0a1__py3-none-any.whl

This diff shows the content of publicly released versions of this package as they appear in their respective public registries; it is provided for informational purposes only.

This version of hatchet-sdk has been flagged as a potentially problematic release.

Files changed (65)
  1. hatchet_sdk/__init__.py +27 -16
  2. hatchet_sdk/client.py +13 -63
  3. hatchet_sdk/clients/admin.py +203 -124
  4. hatchet_sdk/clients/dispatcher/action_listener.py +42 -42
  5. hatchet_sdk/clients/dispatcher/dispatcher.py +18 -16
  6. hatchet_sdk/clients/durable_event_listener.py +327 -0
  7. hatchet_sdk/clients/rest/__init__.py +12 -1
  8. hatchet_sdk/clients/rest/api/log_api.py +258 -0
  9. hatchet_sdk/clients/rest/api/task_api.py +32 -6
  10. hatchet_sdk/clients/rest/api/workflow_runs_api.py +626 -0
  11. hatchet_sdk/clients/rest/models/__init__.py +12 -1
  12. hatchet_sdk/clients/rest/models/v1_log_line.py +94 -0
  13. hatchet_sdk/clients/rest/models/v1_log_line_level.py +39 -0
  14. hatchet_sdk/clients/rest/models/v1_log_line_list.py +110 -0
  15. hatchet_sdk/clients/rest/models/v1_task_summary.py +80 -64
  16. hatchet_sdk/clients/rest/models/v1_trigger_workflow_run_request.py +95 -0
  17. hatchet_sdk/clients/rest/models/v1_workflow_run_display_name.py +98 -0
  18. hatchet_sdk/clients/rest/models/v1_workflow_run_display_name_list.py +114 -0
  19. hatchet_sdk/clients/rest/models/workflow_run_shape_item_for_workflow_run_details.py +9 -4
  20. hatchet_sdk/clients/rest_client.py +21 -0
  21. hatchet_sdk/clients/run_event_listener.py +0 -1
  22. hatchet_sdk/context/context.py +85 -147
  23. hatchet_sdk/contracts/dispatcher_pb2_grpc.py +1 -1
  24. hatchet_sdk/contracts/events_pb2.py +2 -2
  25. hatchet_sdk/contracts/events_pb2_grpc.py +1 -1
  26. hatchet_sdk/contracts/v1/dispatcher_pb2.py +36 -0
  27. hatchet_sdk/contracts/v1/dispatcher_pb2.pyi +38 -0
  28. hatchet_sdk/contracts/v1/dispatcher_pb2_grpc.py +145 -0
  29. hatchet_sdk/contracts/v1/shared/condition_pb2.py +39 -0
  30. hatchet_sdk/contracts/v1/shared/condition_pb2.pyi +72 -0
  31. hatchet_sdk/contracts/v1/shared/condition_pb2_grpc.py +29 -0
  32. hatchet_sdk/contracts/v1/workflows_pb2.py +67 -0
  33. hatchet_sdk/contracts/v1/workflows_pb2.pyi +228 -0
  34. hatchet_sdk/contracts/v1/workflows_pb2_grpc.py +234 -0
  35. hatchet_sdk/contracts/workflows_pb2_grpc.py +1 -1
  36. hatchet_sdk/features/cron.py +3 -3
  37. hatchet_sdk/features/scheduled.py +2 -2
  38. hatchet_sdk/hatchet.py +427 -151
  39. hatchet_sdk/opentelemetry/instrumentor.py +8 -13
  40. hatchet_sdk/rate_limit.py +33 -39
  41. hatchet_sdk/runnables/contextvars.py +12 -0
  42. hatchet_sdk/runnables/standalone.py +194 -0
  43. hatchet_sdk/runnables/task.py +144 -0
  44. hatchet_sdk/runnables/types.py +138 -0
  45. hatchet_sdk/runnables/workflow.py +764 -0
  46. hatchet_sdk/utils/aio_utils.py +0 -79
  47. hatchet_sdk/utils/proto_enums.py +0 -7
  48. hatchet_sdk/utils/timedelta_to_expression.py +23 -0
  49. hatchet_sdk/utils/typing.py +2 -2
  50. hatchet_sdk/v0/clients/rest_client.py +9 -0
  51. hatchet_sdk/v0/worker/action_listener_process.py +18 -2
  52. hatchet_sdk/waits.py +120 -0
  53. hatchet_sdk/worker/action_listener_process.py +64 -30
  54. hatchet_sdk/worker/runner/run_loop_manager.py +35 -25
  55. hatchet_sdk/worker/runner/runner.py +72 -49
  56. hatchet_sdk/worker/runner/utils/capture_logs.py +3 -11
  57. hatchet_sdk/worker/worker.py +155 -118
  58. hatchet_sdk/workflow_run.py +4 -5
  59. {hatchet_sdk-1.0.0.dist-info → hatchet_sdk-1.0.0a1.dist-info}/METADATA +1 -2
  60. {hatchet_sdk-1.0.0.dist-info → hatchet_sdk-1.0.0a1.dist-info}/RECORD +62 -42
  61. {hatchet_sdk-1.0.0.dist-info → hatchet_sdk-1.0.0a1.dist-info}/entry_points.txt +2 -0
  62. hatchet_sdk/semver.py +0 -30
  63. hatchet_sdk/worker/runner/utils/error_with_traceback.py +0 -6
  64. hatchet_sdk/workflow.py +0 -527
  65. {hatchet_sdk-1.0.0.dist-info → hatchet_sdk-1.0.0a1.dist-info}/WHEEL +0 -0
hatchet_sdk/__init__.py CHANGED
@@ -1,10 +1,9 @@
-from hatchet_sdk.client import new_client
 from hatchet_sdk.clients.admin import (
-    ChildTriggerWorkflowOptions,
     DedupeViolationErr,
     ScheduleTriggerWorkflowOptions,
     TriggerWorkflowOptions,
 )
+from hatchet_sdk.clients.durable_event_listener import RegisterDurableEventRequest
 from hatchet_sdk.clients.events import PushEventOptions
 from hatchet_sdk.clients.rest.models.accept_invite_request import AcceptInviteRequest

@@ -76,7 +75,6 @@ from hatchet_sdk.clients.rest.models.pull_request_state import PullRequestState
 from hatchet_sdk.clients.rest.models.reject_invite_request import RejectInviteRequest
 from hatchet_sdk.clients.rest.models.replay_event_request import ReplayEventRequest
 from hatchet_sdk.clients.rest.models.rerun_step_run_request import RerunStepRunRequest
-from hatchet_sdk.clients.rest.models.step import Step
 from hatchet_sdk.clients.rest.models.step_run import StepRun
 from hatchet_sdk.clients.rest.models.step_run_diff import StepRunDiff
 from hatchet_sdk.clients.rest.models.step_run_status import StepRunStatus
@@ -130,7 +128,7 @@ from hatchet_sdk.clients.run_event_listener import (
     WorkflowRunEventType,
 )
 from hatchet_sdk.config import ClientConfig
-from hatchet_sdk.context.context import Context
+from hatchet_sdk.context.context import Context, DurableContext
 from hatchet_sdk.context.worker_context import WorkerContext
 from hatchet_sdk.contracts.workflows_pb2 import (
     CreateWorkflowVersionOpts,
@@ -138,15 +136,24 @@ from hatchet_sdk.contracts.workflows_pb2 import (
     WorkerLabelComparator,
 )
 from hatchet_sdk.hatchet import Hatchet
-from hatchet_sdk.utils.aio_utils import sync_to_async
-from hatchet_sdk.worker.worker import Worker, WorkerStartOptions, WorkerStatus
-from hatchet_sdk.workflow import (
-    BaseWorkflow,
+from hatchet_sdk.runnables.task import Task
+from hatchet_sdk.runnables.types import (
     ConcurrencyExpression,
     ConcurrencyLimitStrategy,
+    EmptyModel,
     StickyStrategy,
+    TaskDefaults,
     WorkflowConfig,
 )
+from hatchet_sdk.waits import (
+    Condition,
+    OrGroup,
+    ParentCondition,
+    SleepCondition,
+    UserEventCondition,
+    or_,
+)
+from hatchet_sdk.worker.worker import Worker, WorkerStartOptions, WorkerStatus

 __all__ = [
     "AcceptInviteRequest",
@@ -191,11 +198,9 @@ __all__ = [
     "RejectInviteRequest",
     "ReplayEventRequest",
     "RerunStepRunRequest",
-    "Step",
     "StepRun",
     "StepRunDiff",
     "StepRunStatus",
-    "sync_to_async",
     "Tenant",
     "TenantInvite",
     "TenantInviteList",
@@ -231,8 +236,6 @@ __all__ = [
     "CreateWorkflowVersionOpts",
     "RateLimitDuration",
     "StickyStrategy",
-    "new_client",
-    "ChildTriggerWorkflowOptions",
     "DedupeViolationErr",
     "ScheduleTriggerWorkflowOptions",
     "TriggerWorkflowOptions",
@@ -243,14 +246,22 @@ __all__ = [
     "WorkerContext",
     "ClientConfig",
     "Hatchet",
-    "concurrency",
-    "on_failure_step",
-    "step",
     "workflow",
     "Worker",
     "WorkerStartOptions",
     "WorkerStatus",
     "ConcurrencyExpression",
-    "BaseWorkflow",
+    "Workflow",
     "WorkflowConfig",
+    "Task",
+    "EmptyModel",
+    "Condition",
+    "OrGroup",
+    "or_",
+    "SleepCondition",
+    "UserEventCondition",
+    "ParentCondition",
+    "DurableContext",
+    "RegisterDurableEventRequest",
+    "TaskDefaults",
 ]
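
Taken together, these `__init__.py` hunks replace the v0 public surface (`new_client`, `BaseWorkflow`, `step`, `concurrency`, `on_failure_step`, `sync_to_async`, `Step`, `ChildTriggerWorkflowOptions`) with the new runnables, waits, and durable-context exports. A hedged migration sketch derived only from the export list above, not from official upgrade docs:

# Top-level exports removed in 1.0.0a1 (these imports no longer work):
# from hatchet_sdk import new_client, BaseWorkflow, sync_to_async, Step, ChildTriggerWorkflowOptions

# Top-level exports added in 1.0.0a1:
from hatchet_sdk import (
    Condition,
    DurableContext,
    EmptyModel,
    OrGroup,
    ParentCondition,
    RegisterDurableEventRequest,
    SleepCondition,
    Task,
    TaskDefaults,
    UserEventCondition,
    Workflow,
    or_,
)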
hatchet_sdk/client.py CHANGED
@@ -1,5 +1,4 @@
 import asyncio
-from typing import Callable

 import grpc

@@ -14,63 +13,14 @@ from hatchet_sdk.connection import new_conn


 class Client:
-    @classmethod
-    def from_environment(
-        cls,
-        defaults: ClientConfig = ClientConfig(),
-        debug: bool = False,
-        *opts_functions: Callable[[ClientConfig], None],
-    ) -> "Client":
-        try:
-            loop = asyncio.get_running_loop()
-        except RuntimeError:
-            loop = asyncio.new_event_loop()
-            asyncio.set_event_loop(loop)
-
-        for opt_function in opts_functions:
-            opt_function(defaults)
-
-        return cls.from_config(defaults, debug)
-
-    @classmethod
-    def from_config(
-        cls,
-        config: ClientConfig = ClientConfig(),
-        debug: bool = False,
-    ) -> "Client":
-        try:
-            loop = asyncio.get_running_loop()
-        except RuntimeError:
-            loop = asyncio.new_event_loop()
-            asyncio.set_event_loop(loop)
-
-        conn: grpc.Channel = new_conn(config, False)
-
-        # Instantiate clients
-        event_client = new_event(conn, config)
-        admin_client = AdminClient(config)
-        dispatcher_client = DispatcherClient(config)
-        rest_client = RestApi(config.server_url, config.token, config.tenant_id)
-        workflow_listener = None  # Initialize this if needed
-
-        return cls(
-            event_client,
-            admin_client,
-            dispatcher_client,
-            workflow_listener,
-            rest_client,
-            config,
-            debug,
-        )
-
     def __init__(
         self,
-        event_client: EventClient,
-        admin_client: AdminClient,
-        dispatcher_client: DispatcherClient,
-        workflow_listener: PooledWorkflowRunListener | None,
-        rest_client: RestApi,
         config: ClientConfig,
+        event_client: EventClient | None = None,
+        admin_client: AdminClient | None = None,
+        dispatcher_client: DispatcherClient | None = None,
+        workflow_listener: PooledWorkflowRunListener | None | None = None,
+        rest_client: RestApi | None = None,
         debug: bool = False,
     ):
         try:
@@ -79,16 +29,16 @@ class Client:
             loop = asyncio.new_event_loop()
             asyncio.set_event_loop(loop)

-        self.admin = admin_client
-        self.dispatcher = dispatcher_client
-        self.event = event_client
-        self.rest = rest_client
+        conn: grpc.Channel = new_conn(config, False)
+
         self.config = config
+        self.admin = admin_client or AdminClient(config)
+        self.dispatcher = dispatcher_client or DispatcherClient(config)
+        self.event = event_client or new_event(conn, config)
+        self.rest = rest_client or RestApi(
+            config.server_url, config.token, config.tenant_id
+        )
         self.listener = RunEventListenerClient(config)
         self.workflow_listener = workflow_listener
         self.logInterceptor = config.logger
         self.debug = debug
-
-
-new_client = Client.from_environment
-new_client_raw = Client.from_config
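
With `from_environment`, `from_config`, and the module-level `new_client`/`new_client_raw` aliases gone, the constructor now builds any sub-client that is not passed in from the config. A minimal sketch of direct construction under the new signature (most code is expected to go through `Hatchet` rather than instantiating `Client` itself; assumes `ClientConfig()` picks up the token and server settings from the environment):

from hatchet_sdk import ClientConfig
from hatchet_sdk.client import Client

config = ClientConfig()         # assumption: connection settings come from env vars
client = Client(config=config)  # admin, dispatcher, event, and rest clients are created from config

# 1.0.0 equivalents that no longer exist: new_client(), Client.from_config(config)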
hatchet_sdk/clients/admin.py CHANGED
@@ -5,18 +5,27 @@ from typing import Union, cast

 import grpc
 from google.protobuf import timestamp_pb2
-from pydantic import BaseModel, Field
+from pydantic import BaseModel, ConfigDict, Field, field_validator

 from hatchet_sdk.clients.rest.tenacity_utils import tenacity_retry
 from hatchet_sdk.clients.run_event_listener import RunEventListenerClient
 from hatchet_sdk.clients.workflow_listener import PooledWorkflowRunListener
 from hatchet_sdk.config import ClientConfig
 from hatchet_sdk.connection import new_conn
-from hatchet_sdk.contracts import workflows_pb2 as workflow_protos
+from hatchet_sdk.contracts import workflows_pb2 as v0_workflow_protos
+from hatchet_sdk.contracts.v1 import workflows_pb2 as workflow_protos
+from hatchet_sdk.contracts.v1.workflows_pb2_grpc import AdminServiceStub
 from hatchet_sdk.contracts.workflows_pb2_grpc import WorkflowServiceStub
 from hatchet_sdk.metadata import get_metadata
 from hatchet_sdk.rate_limit import RateLimitDuration
-from hatchet_sdk.utils.proto_enums import convert_python_enum_to_proto, maybe_int_to_str
+from hatchet_sdk.runnables.contextvars import (
+    ctx_step_run_id,
+    ctx_worker_id,
+    ctx_workflow_run_id,
+    spawn_index_lock,
+    workflow_spawn_indices,
+)
+from hatchet_sdk.utils.proto_enums import convert_python_enum_to_proto
 from hatchet_sdk.utils.typing import JSONSerializableMapping
 from hatchet_sdk.workflow_run import WorkflowRunRef

@@ -29,28 +38,19 @@ class ScheduleTriggerWorkflowOptions(BaseModel):
     namespace: str | None = None


-class ChildTriggerWorkflowOptions(BaseModel):
-    additional_metadata: JSONSerializableMapping = Field(default_factory=dict)
-    sticky: bool = False
-
-
-class ChildWorkflowRunDict(BaseModel):
-    workflow_name: str
-    input: JSONSerializableMapping
-    options: ChildTriggerWorkflowOptions
-    key: str | None = None
-
-
 class TriggerWorkflowOptions(ScheduleTriggerWorkflowOptions):
     additional_metadata: JSONSerializableMapping = Field(default_factory=dict)
     desired_worker_id: str | None = None
     namespace: str | None = None
+    sticky: bool = False
+    key: str | None = None


-class WorkflowRunDict(BaseModel):
+class WorkflowRunTriggerConfig(BaseModel):
     workflow_name: str
     input: JSONSerializableMapping
     options: TriggerWorkflowOptions
+    key: str | None = None


 class DedupeViolationErr(Exception):
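
`ChildTriggerWorkflowOptions` and `ChildWorkflowRunDict` are folded into `TriggerWorkflowOptions` (which gains `sticky` and `key`) and the renamed `WorkflowRunTriggerConfig`. A hedged sketch of the new request shapes, using only fields visible in this hunk (the workflow name and input values are illustrative):

from hatchet_sdk.clients.admin import TriggerWorkflowOptions, WorkflowRunTriggerConfig

# Fields that previously lived on ChildTriggerWorkflowOptions now sit on TriggerWorkflowOptions.
opts = TriggerWorkflowOptions(
    additional_metadata={"source": "backfill"},
    sticky=True,        # was ChildTriggerWorkflowOptions.sticky
    key="child-key-1",  # was ChildWorkflowRunDict.key
)

# WorkflowRunTriggerConfig replaces WorkflowRunDict for (bulk) triggering.
run = WorkflowRunTriggerConfig(
    workflow_name="my-workflow",
    input={"user_id": 123},
    options=opts,
)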
@@ -63,55 +63,69 @@ class AdminClient:
     def __init__(self, config: ClientConfig):
         conn = new_conn(config, False)
         self.config = config
-        self.client = WorkflowServiceStub(conn)  # type: ignore[no-untyped-call]
+        self.client = AdminServiceStub(conn)  # type: ignore[no-untyped-call]
+        self.v0_client = WorkflowServiceStub(conn)  # type: ignore[no-untyped-call]
         self.token = config.token
         self.listener_client = RunEventListenerClient(config=config)
         self.namespace = config.namespace

         self.pooled_workflow_listener: PooledWorkflowRunListener | None = None

+    class TriggerWorkflowRequest(BaseModel):
+        model_config = ConfigDict(extra="ignore")
+
+        parent_id: str | None = None
+        parent_step_run_id: str | None = None
+        child_index: int | None = None
+        child_key: str | None = None
+        additional_metadata: str | None = None
+        desired_worker_id: str | None = None
+        priority: int | None = None
+
+        @field_validator("additional_metadata", mode="before")
+        @classmethod
+        def validate_additional_metadata(
+            cls, v: JSONSerializableMapping | None
+        ) -> bytes | None:
+            if not v:
+                return None
+
+            try:
+                return json.dumps(v).encode("utf-8")
+            except json.JSONDecodeError as e:
+                raise ValueError(f"Error encoding payload: {e}")
+
     def _prepare_workflow_request(
         self,
         workflow_name: str,
         input: JSONSerializableMapping,
         options: TriggerWorkflowOptions,
-    ) -> workflow_protos.TriggerWorkflowRequest:
+    ) -> v0_workflow_protos.TriggerWorkflowRequest:
         try:
             payload_data = json.dumps(input)
-            _options = options.model_dump()
-
-            _options.pop("namespace")
-
-            try:
-                _options = {
-                    **_options,
-                    "additional_metadata": json.dumps(
-                        options.additional_metadata
-                    ).encode("utf-8"),
-                }
-            except json.JSONDecodeError as e:
-                raise ValueError(f"Error encoding payload: {e}")
-
-            return workflow_protos.TriggerWorkflowRequest(
-                name=workflow_name, input=payload_data, **_options
-            )
         except json.JSONDecodeError as e:
             raise ValueError(f"Error encoding payload: {e}")

+        _options = self.TriggerWorkflowRequest.model_validate(
+            options.model_dump()
+        ).model_dump()
+
+        return v0_workflow_protos.TriggerWorkflowRequest(
+            name=workflow_name, input=payload_data, **_options
+        )
+
     def _prepare_put_workflow_request(
         self,
         name: str,
-        workflow: workflow_protos.CreateWorkflowVersionOpts,
-        overrides: workflow_protos.CreateWorkflowVersionOpts | None = None,
-    ) -> workflow_protos.PutWorkflowRequest:
+        workflow: workflow_protos.CreateWorkflowVersionRequest,
+        overrides: workflow_protos.CreateWorkflowVersionRequest | None = None,
+    ) -> workflow_protos.CreateWorkflowVersionRequest:
         if overrides is not None:
             workflow.MergeFrom(overrides)

         workflow.name = name

-        return workflow_protos.PutWorkflowRequest(
-            opts=workflow,
-        )
+        return workflow

     def _parse_schedule(
         self, schedule: datetime | timestamp_pb2.Timestamp
@@ -134,51 +148,21 @@
         schedules: list[Union[datetime, timestamp_pb2.Timestamp]],
         input: JSONSerializableMapping = {},
         options: ScheduleTriggerWorkflowOptions = ScheduleTriggerWorkflowOptions(),
-    ) -> workflow_protos.ScheduleWorkflowRequest:
-        return workflow_protos.ScheduleWorkflowRequest(
+    ) -> v0_workflow_protos.ScheduleWorkflowRequest:
+        return v0_workflow_protos.ScheduleWorkflowRequest(
             name=name,
             schedules=[self._parse_schedule(schedule) for schedule in schedules],
             input=json.dumps(input),
             **options.model_dump(),
         )

-    ## IMPORTANT: Keep this method's signature in sync with the wrapper in the OTel instrumentor
-    @tenacity_retry
-    async def aio_run_workflow(
-        self,
-        workflow_name: str,
-        input: JSONSerializableMapping,
-        options: TriggerWorkflowOptions = TriggerWorkflowOptions(),
-    ) -> WorkflowRunRef:
-        ## IMPORTANT: The `pooled_workflow_listener` must be created 1) lazily, and not at `init` time, and 2) on the
-        ## main thread. If 1) is not followed, you'll get an error about something being attached to the wrong event
-        ## loop. If 2) is not followed, you'll get an error about the event loop not being set up.
-        if not self.pooled_workflow_listener:
-            self.pooled_workflow_listener = PooledWorkflowRunListener(self.config)
-
-        return await asyncio.to_thread(self.run_workflow, workflow_name, input, options)
-
-    @tenacity_retry
-    async def aio_run_workflows(
-        self,
-        workflows: list[WorkflowRunDict],
-        options: TriggerWorkflowOptions = TriggerWorkflowOptions(),
-    ) -> list[WorkflowRunRef]:
-        ## IMPORTANT: The `pooled_workflow_listener` must be created 1) lazily, and not at `init` time, and 2) on the
-        ## main thread. If 1) is not followed, you'll get an error about something being attached to the wrong event
-        ## loop. If 2) is not followed, you'll get an error about the event loop not being set up.
-        if not self.pooled_workflow_listener:
-            self.pooled_workflow_listener = PooledWorkflowRunListener(self.config)
-
-        return await asyncio.to_thread(self.run_workflows, workflows, options)
-
     @tenacity_retry
     async def aio_put_workflow(
         self,
         name: str,
-        workflow: workflow_protos.CreateWorkflowVersionOpts,
-        overrides: workflow_protos.CreateWorkflowVersionOpts | None = None,
-    ) -> workflow_protos.WorkflowVersion:
+        workflow: workflow_protos.CreateWorkflowVersionRequest,
+        overrides: workflow_protos.CreateWorkflowVersionRequest | None = None,
+    ) -> workflow_protos.CreateWorkflowVersionResponse:
         ## IMPORTANT: The `pooled_workflow_listener` must be created 1) lazily, and not at `init` time, and 2) on the
         ## main thread. If 1) is not followed, you'll get an error about something being attached to the wrong event
         ## loop. If 2) is not followed, you'll get an error about the event loop not being set up.
@@ -209,7 +193,7 @@
         schedules: list[Union[datetime, timestamp_pb2.Timestamp]],
         input: JSONSerializableMapping = {},
         options: ScheduleTriggerWorkflowOptions = ScheduleTriggerWorkflowOptions(),
-    ) -> workflow_protos.WorkflowVersion:
+    ) -> v0_workflow_protos.WorkflowVersion:
         ## IMPORTANT: The `pooled_workflow_listener` must be created 1) lazily, and not at `init` time, and 2) on the
         ## main thread. If 1) is not followed, you'll get an error about something being attached to the wrong event
         ## loop. If 2) is not followed, you'll get an error about the event loop not being set up.
@@ -224,18 +208,19 @@
     def put_workflow(
         self,
         name: str,
-        workflow: workflow_protos.CreateWorkflowVersionOpts,
-        overrides: workflow_protos.CreateWorkflowVersionOpts | None = None,
-    ) -> workflow_protos.WorkflowVersion:
+        workflow: workflow_protos.CreateWorkflowVersionRequest,
+        overrides: workflow_protos.CreateWorkflowVersionRequest | None = None,
+    ) -> workflow_protos.CreateWorkflowVersionResponse:
         opts = self._prepare_put_workflow_request(name, workflow, overrides)

-        resp: workflow_protos.WorkflowVersion = self.client.PutWorkflow(
-            opts,
-            metadata=get_metadata(self.token),
+        return cast(
+            workflow_protos.CreateWorkflowVersionResponse,
+            self.client.PutWorkflow(
+                opts,
+                metadata=get_metadata(self.token),
+            ),
         )

-        return resp
-
     @tenacity_retry
     def put_rate_limit(
         self,
@@ -246,11 +231,12 @@
         duration_proto = convert_python_enum_to_proto(
             duration, workflow_protos.RateLimitDuration
         )
-        self.client.PutRateLimit(
-            workflow_protos.PutRateLimitRequest(
+
+        self.v0_client.PutRateLimit(
+            v0_workflow_protos.PutRateLimitRequest(
                 key=key,
                 limit=limit,
-                duration=maybe_int_to_str(duration_proto),
+                duration=duration_proto,  # type: ignore[arg-type]
             ),
             metadata=get_metadata(self.token),
         )
@@ -262,7 +248,7 @@
         schedules: list[Union[datetime, timestamp_pb2.Timestamp]],
         input: JSONSerializableMapping = {},
         options: ScheduleTriggerWorkflowOptions = ScheduleTriggerWorkflowOptions(),
-    ) -> workflow_protos.WorkflowVersion:
+    ) -> v0_workflow_protos.WorkflowVersion:
         try:
             namespace = options.namespace or self.namespace

@@ -274,8 +260,8 @@
             )

             return cast(
-                workflow_protos.WorkflowVersion,
-                self.client.ScheduleWorkflow(
+                v0_workflow_protos.WorkflowVersion,
+                self.v0_client.ScheduleWorkflow(
                     request,
                     metadata=get_metadata(self.token),
                 ),
@@ -286,6 +272,44 @@

             raise e

+    def _create_workflow_run_request(
+        self,
+        workflow_name: str,
+        input: JSONSerializableMapping,
+        options: TriggerWorkflowOptions,
+    ) -> v0_workflow_protos.TriggerWorkflowRequest:
+        workflow_run_id = ctx_workflow_run_id.get()
+        step_run_id = ctx_step_run_id.get()
+        worker_id = ctx_worker_id.get()
+        spawn_index = workflow_spawn_indices[workflow_run_id] if workflow_run_id else 0
+
+        ## Increment the spawn_index for the parent workflow
+        if workflow_run_id:
+            workflow_spawn_indices[workflow_run_id] += 1
+
+        desired_worker_id = (
+            (options.desired_worker_id or worker_id) if options.sticky else None
+        )
+        child_index = (
+            options.child_index if options.child_index is not None else spawn_index
+        )
+
+        trigger_options = TriggerWorkflowOptions(
+            parent_id=options.parent_id or workflow_run_id,
+            parent_step_run_id=options.parent_step_run_id or step_run_id,
+            child_key=options.child_key,
+            child_index=child_index,
+            additional_metadata=options.additional_metadata,
+            desired_worker_id=desired_worker_id,
+        )
+
+        namespace = options.namespace or self.namespace
+
+        if namespace != "" and not workflow_name.startswith(self.namespace):
+            workflow_name = f"{namespace}{workflow_name}"
+
+        return self._prepare_workflow_request(workflow_name, input, trigger_options)
+
     ## IMPORTANT: Keep this method's signature in sync with the wrapper in the OTel instrumentor
     @tenacity_retry
     def run_workflow(
@@ -294,70 +318,125 @@
         input: JSONSerializableMapping,
         options: TriggerWorkflowOptions = TriggerWorkflowOptions(),
     ) -> WorkflowRunRef:
+        request = self._create_workflow_run_request(workflow_name, input, options)
+
+        if not self.pooled_workflow_listener:
+            self.pooled_workflow_listener = PooledWorkflowRunListener(self.config)
+
         try:
-            if not self.pooled_workflow_listener:
-                self.pooled_workflow_listener = PooledWorkflowRunListener(self.config)
+            resp = cast(
+                v0_workflow_protos.TriggerWorkflowResponse,
+                self.v0_client.TriggerWorkflow(
+                    request,
+                    metadata=get_metadata(self.token),
+                ),
+            )
+        except (grpc.RpcError, grpc.aio.AioRpcError) as e:
+            if e.code() == grpc.StatusCode.ALREADY_EXISTS:
+                raise DedupeViolationErr(e.details())

-            namespace = options.namespace or self.namespace
+        return WorkflowRunRef(
+            workflow_run_id=resp.workflow_run_id,
+            workflow_listener=self.pooled_workflow_listener,
+            workflow_run_event_listener=self.listener_client,
+        )

-            if namespace != "" and not workflow_name.startswith(self.namespace):
-                workflow_name = f"{namespace}{workflow_name}"
+    ## IMPORTANT: Keep this method's signature in sync with the wrapper in the OTel instrumentor
+    @tenacity_retry
+    async def aio_run_workflow(
+        self,
+        workflow_name: str,
+        input: JSONSerializableMapping,
+        options: TriggerWorkflowOptions = TriggerWorkflowOptions(),
+    ) -> WorkflowRunRef:
+        ## IMPORTANT: The `pooled_workflow_listener` must be created 1) lazily, and not at `init` time, and 2) on the
+        ## main thread. If 1) is not followed, you'll get an error about something being attached to the wrong event
+        ## loop. If 2) is not followed, you'll get an error about the event loop not being set up.
+        async with spawn_index_lock:
+            request = self._create_workflow_run_request(workflow_name, input, options)

-            request = self._prepare_workflow_request(workflow_name, input, options)
+        if not self.pooled_workflow_listener:
+            self.pooled_workflow_listener = PooledWorkflowRunListener(self.config)

+        try:
             resp = cast(
-                workflow_protos.TriggerWorkflowResponse,
-                self.client.TriggerWorkflow(
+                v0_workflow_protos.TriggerWorkflowResponse,
+                self.v0_client.TriggerWorkflow(
                     request,
                     metadata=get_metadata(self.token),
                 ),
             )
-
-            return WorkflowRunRef(
-                workflow_run_id=resp.workflow_run_id,
-                workflow_listener=self.pooled_workflow_listener,
-                workflow_run_event_listener=self.listener_client,
-            )
         except (grpc.RpcError, grpc.aio.AioRpcError) as e:
             if e.code() == grpc.StatusCode.ALREADY_EXISTS:
                 raise DedupeViolationErr(e.details())

             raise e

-    def _prepare_workflow_run_request(
-        self, workflow: WorkflowRunDict, options: TriggerWorkflowOptions
-    ) -> workflow_protos.TriggerWorkflowRequest:
-        workflow_name = workflow.workflow_name
-        input_data = workflow.input
-        options = workflow.options
-
-        namespace = options.namespace or self.namespace
-
-        if namespace != "" and not workflow_name.startswith(self.namespace):
-            workflow_name = f"{namespace}{workflow_name}"
-
-        return self._prepare_workflow_request(workflow_name, input_data, options)
+        return WorkflowRunRef(
+            workflow_run_id=resp.workflow_run_id,
+            workflow_listener=self.pooled_workflow_listener,
+            workflow_run_event_listener=self.listener_client,
+        )

     ## IMPORTANT: Keep this method's signature in sync with the wrapper in the OTel instrumentor
     @tenacity_retry
     def run_workflows(
         self,
-        workflows: list[WorkflowRunDict],
-        options: TriggerWorkflowOptions = TriggerWorkflowOptions(),
+        workflows: list[WorkflowRunTriggerConfig],
     ) -> list[WorkflowRunRef]:
         if not self.pooled_workflow_listener:
             self.pooled_workflow_listener = PooledWorkflowRunListener(self.config)

-        bulk_request = workflow_protos.BulkTriggerWorkflowRequest(
+        bulk_request = v0_workflow_protos.BulkTriggerWorkflowRequest(
             workflows=[
-                self._prepare_workflow_run_request(workflow, options)
+                self._create_workflow_run_request(
+                    workflow.workflow_name, workflow.input, workflow.options
+                )
                 for workflow in workflows
             ]
         )

         resp = cast(
-            workflow_protos.BulkTriggerWorkflowResponse,
-            self.client.BulkTriggerWorkflow(
+            v0_workflow_protos.BulkTriggerWorkflowResponse,
+            self.v0_client.BulkTriggerWorkflow(
+                bulk_request,
+                metadata=get_metadata(self.token),
+            ),
+        )
+
+        return [
+            WorkflowRunRef(
+                workflow_run_id=workflow_run_id,
+                workflow_listener=self.pooled_workflow_listener,
+                workflow_run_event_listener=self.listener_client,
+            )
+            for workflow_run_id in resp.workflow_run_ids
+        ]
+
+    @tenacity_retry
+    async def aio_run_workflows(
+        self,
+        workflows: list[WorkflowRunTriggerConfig],
+    ) -> list[WorkflowRunRef]:
+        ## IMPORTANT: The `pooled_workflow_listener` must be created 1) lazily, and not at `init` time, and 2) on the
+        ## main thread. If 1) is not followed, you'll get an error about something being attached to the wrong event
+        ## loop. If 2) is not followed, you'll get an error about the event loop not being set up.
+        if not self.pooled_workflow_listener:
+            self.pooled_workflow_listener = PooledWorkflowRunListener(self.config)
+
+        async with spawn_index_lock:
+            bulk_request = v0_workflow_protos.BulkTriggerWorkflowRequest(
+                workflows=[
+                    self._create_workflow_run_request(
+                        workflow.workflow_name, workflow.input, workflow.options
+                    )
+                    for workflow in workflows
+                ]
+            )
+
+        resp = cast(
+            v0_workflow_protos.BulkTriggerWorkflowResponse,
+            self.v0_client.BulkTriggerWorkflow(
                 bulk_request,
                 metadata=get_metadata(self.token),
             ),
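
The remainder of the admin.py diff is truncated here, but the signatures above already show the bulk-trigger change: run_workflows and aio_run_workflows now take a list of WorkflowRunTriggerConfig, each carrying its own options, instead of a workflow list plus one shared options argument. A hedged usage sketch against the new synchronous signature (workflow names and inputs are illustrative):

from hatchet_sdk.clients.admin import (
    AdminClient,
    TriggerWorkflowOptions,
    WorkflowRunTriggerConfig,
)

def trigger_batch(admin: AdminClient) -> None:
    runs = [
        WorkflowRunTriggerConfig(
            workflow_name="process-order",
            input={"order_id": i},
            options=TriggerWorkflowOptions(additional_metadata={"batch": "nightly"}),
        )
        for i in range(3)
    ]
    # 1.0.0 signature was run_workflows(workflows, options) with one shared options object.
    refs = admin.run_workflows(runs)
    for ref in refs:
        print(ref.workflow_run_id)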