hatchet-sdk 0.42.3__py3-none-any.whl → 0.42.5__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of hatchet-sdk might be problematic; see the registry's advisory page for this release for more details.

@@ -41,7 +41,7 @@ def new_admin(config: ClientConfig):
41
41
  return AdminClient(config)
42
42
 
43
43
 
44
- class ScheduleTriggerWorkflowOptions(TypedDict):
44
+ class ScheduleTriggerWorkflowOptions(TypedDict, total=False):
45
45
  parent_id: Optional[str]
46
46
  parent_step_run_id: Optional[str]
47
47
  child_index: Optional[int]
@@ -49,25 +49,25 @@ class ScheduleTriggerWorkflowOptions(TypedDict):
49
49
  namespace: Optional[str]
50
50
 
51
51
 
52
- class ChildTriggerWorkflowOptions(TypedDict):
52
+ class ChildTriggerWorkflowOptions(TypedDict, total=False):
53
53
  additional_metadata: Dict[str, str] | None = None
54
54
  sticky: bool | None = None
55
55
 
56
56
 
57
- class ChildWorkflowRunDict(TypedDict):
57
+ class ChildWorkflowRunDict(TypedDict, total=False):
58
58
  workflow_name: str
59
59
  input: Any
60
60
  options: ChildTriggerWorkflowOptions
61
61
  key: str | None = None
62
62
 
63
63
 
64
- class TriggerWorkflowOptions(ScheduleTriggerWorkflowOptions, TypedDict):
64
+ class TriggerWorkflowOptions(ScheduleTriggerWorkflowOptions, total=False):
65
65
  additional_metadata: Dict[str, str] | None = None
66
66
  desired_worker_id: str | None = None
67
67
  namespace: str | None = None
68
68
 
69
69
 
70
- class WorkflowRunDict(TypedDict):
70
+ class WorkflowRunDict(TypedDict, total=False):
71
71
  workflow_name: str
72
72
  input: Any
73
73
  options: TriggerWorkflowOptions | None
@@ -425,7 +425,7 @@ class AdminClient(AdminClientBase):
425
425
  self,
426
426
  key: str,
427
427
  limit: int,
428
- duration: RateLimitDuration = RateLimitDuration.SECOND,
428
+ duration: Union[RateLimitDuration.Value, str] = RateLimitDuration.SECOND,
429
429
  ):
430
430
  try:
431
431
  self.client.PutRateLimit(
@@ -1,3 +1,5 @@
1
+ from typing import Any, cast
2
+
1
3
  from google.protobuf.timestamp_pb2 import Timestamp
2
4
 
3
5
  from hatchet_sdk.clients.dispatcher.action_listener import (
@@ -31,7 +33,7 @@ from ...metadata import get_metadata
31
33
  DEFAULT_REGISTER_TIMEOUT = 30
32
34
 
33
35
 
34
- def new_dispatcher(config: ClientConfig):
36
+ def new_dispatcher(config: ClientConfig) -> "DispatcherClient":
35
37
  return DispatcherClient(config=config)
36
38
 
37
39
 
@@ -40,10 +42,10 @@ class DispatcherClient:
40
42
 
41
43
  def __init__(self, config: ClientConfig):
42
44
  conn = new_conn(config)
43
- self.client = DispatcherStub(conn)
45
+ self.client = DispatcherStub(conn) # type: ignore[no-untyped-call]
44
46
 
45
47
  aio_conn = new_conn(config, True)
46
- self.aio_client = DispatcherStub(aio_conn)
48
+ self.aio_client = DispatcherStub(aio_conn) # type: ignore[no-untyped-call]
47
49
  self.token = config.token
48
50
  self.config = config
49
51
 
@@ -67,7 +69,7 @@ class DispatcherClient:
67
69
 
68
70
  async def send_step_action_event(
69
71
  self, action: Action, event_type: StepActionEventType, payload: str
70
- ):
72
+ ) -> Any:
71
73
  try:
72
74
  return await self._try_send_step_action_event(action, event_type, payload)
73
75
  except Exception as e:
@@ -87,7 +89,7 @@ class DispatcherClient:
87
89
  @tenacity_retry
88
90
  async def _try_send_step_action_event(
89
91
  self, action: Action, event_type: StepActionEventType, payload: str
90
- ):
92
+ ) -> Any:
91
93
  eventTimestamp = Timestamp()
92
94
  eventTimestamp.GetCurrentTime()
93
95
 
@@ -103,6 +105,7 @@ class DispatcherClient:
103
105
  eventPayload=payload,
104
106
  )
105
107
 
108
+ ## TODO: What does this return?
106
109
  return await self.aio_client.SendStepActionEvent(
107
110
  event,
108
111
  metadata=get_metadata(self.token),
@@ -110,7 +113,7 @@ class DispatcherClient:
110
113
 
111
114
  async def send_group_key_action_event(
112
115
  self, action: Action, event_type: GroupKeyActionEventType, payload: str
113
- ):
116
+ ) -> Any:
114
117
  eventTimestamp = Timestamp()
115
118
  eventTimestamp.GetCurrentTime()
116
119
 
@@ -124,19 +127,21 @@ class DispatcherClient:
124
127
  eventPayload=payload,
125
128
  )
126
129
 
130
+ ## TODO: What does this return?
127
131
  return await self.aio_client.SendGroupKeyActionEvent(
128
132
  event,
129
133
  metadata=get_metadata(self.token),
130
134
  )
131
135
 
132
- def put_overrides_data(self, data: OverridesData):
133
- response: ActionEventResponse = self.client.PutOverridesData(
134
- data,
135
- metadata=get_metadata(self.token),
136
+ def put_overrides_data(self, data: OverridesData) -> ActionEventResponse:
137
+ return cast(
138
+ ActionEventResponse,
139
+ self.client.PutOverridesData(
140
+ data,
141
+ metadata=get_metadata(self.token),
142
+ ),
136
143
  )
137
144
 
138
- return response
139
-
140
145
  def release_slot(self, step_run_id: str) -> None:
141
146
  self.client.ReleaseSlot(
142
147
  ReleaseSlotRequest(stepRunId=step_run_id),
@@ -154,7 +159,9 @@ class DispatcherClient:
154
159
  metadata=get_metadata(self.token),
155
160
  )
156
161
 
157
- def upsert_worker_labels(self, worker_id: str, labels: dict[str, str | int]):
162
+ def upsert_worker_labels(
163
+ self, worker_id: str | None, labels: dict[str, str | int]
164
+ ) -> None:
158
165
  worker_labels = {}
159
166
 
160
167
  for key, value in labels.items():
@@ -171,9 +178,9 @@ class DispatcherClient:
171
178
 
172
179
  async def async_upsert_worker_labels(
173
180
  self,
174
- worker_id: str,
181
+ worker_id: str | None,
175
182
  labels: dict[str, str | int],
176
- ):
183
+ ) -> None:
177
184
  worker_labels = {}
178
185
 
179
186
  for key, value in labels.items():
@@ -43,16 +43,16 @@ def proto_timestamp_now():
43
43
  return timestamp_pb2.Timestamp(seconds=seconds, nanos=nanos)
44
44
 
45
45
 
46
- class PushEventOptions(TypedDict):
46
+ class PushEventOptions(TypedDict, total=False):
47
47
  additional_metadata: Dict[str, str] | None = None
48
48
  namespace: str | None = None
49
49
 
50
50
 
51
- class BulkPushEventOptions(TypedDict):
51
+ class BulkPushEventOptions(TypedDict, total=False):
52
52
  namespace: str | None = None
53
53
 
54
54
 
55
- class BulkPushEventWithMetadata(TypedDict):
55
+ class BulkPushEventWithMetadata(TypedDict, total=False):
56
56
  key: str
57
57
  payload: Any
58
58
  additional_metadata: Optional[Dict[str, Any]] # Optional metadata
@@ -48,6 +48,266 @@ class DefaultApi:
48
48
  api_client = ApiClient.get_default()
49
49
  self.api_client = api_client
50
50
 
51
+ @validate_call
52
+ async def monitoring_post_run_probe(
53
+ self,
54
+ tenant: Annotated[
55
+ str,
56
+ Field(
57
+ min_length=36, strict=True, max_length=36, description="The tenant id"
58
+ ),
59
+ ],
60
+ _request_timeout: Union[
61
+ None,
62
+ Annotated[StrictFloat, Field(gt=0)],
63
+ Tuple[
64
+ Annotated[StrictFloat, Field(gt=0)], Annotated[StrictFloat, Field(gt=0)]
65
+ ],
66
+ ] = None,
67
+ _request_auth: Optional[Dict[StrictStr, Any]] = None,
68
+ _content_type: Optional[StrictStr] = None,
69
+ _headers: Optional[Dict[StrictStr, Any]] = None,
70
+ _host_index: Annotated[StrictInt, Field(ge=0, le=0)] = 0,
71
+ ) -> None:
72
+ """Detailed Health Probe For the Instance
73
+
74
+ Triggers a workflow to check the status of the instance
75
+
76
+ :param tenant: The tenant id (required)
77
+ :type tenant: str
78
+ :param _request_timeout: timeout setting for this request. If one
79
+ number provided, it will be total request
80
+ timeout. It can also be a pair (tuple) of
81
+ (connection, read) timeouts.
82
+ :type _request_timeout: int, tuple(int, int), optional
83
+ :param _request_auth: set to override the auth_settings for an a single
84
+ request; this effectively ignores the
85
+ authentication in the spec for a single request.
86
+ :type _request_auth: dict, optional
87
+ :param _content_type: force content-type for the request.
88
+ :type _content_type: str, Optional
89
+ :param _headers: set to override the headers for a single
90
+ request; this effectively ignores the headers
91
+ in the spec for a single request.
92
+ :type _headers: dict, optional
93
+ :param _host_index: set to override the host_index for a single
94
+ request; this effectively ignores the host_index
95
+ in the spec for a single request.
96
+ :type _host_index: int, optional
97
+ :return: Returns the result object.
98
+ """ # noqa: E501
99
+
100
+ _param = self._monitoring_post_run_probe_serialize(
101
+ tenant=tenant,
102
+ _request_auth=_request_auth,
103
+ _content_type=_content_type,
104
+ _headers=_headers,
105
+ _host_index=_host_index,
106
+ )
107
+
108
+ _response_types_map: Dict[str, Optional[str]] = {
109
+ "200": None,
110
+ "403": "APIErrors",
111
+ }
112
+ response_data = await self.api_client.call_api(
113
+ *_param, _request_timeout=_request_timeout
114
+ )
115
+ await response_data.read()
116
+ return self.api_client.response_deserialize(
117
+ response_data=response_data,
118
+ response_types_map=_response_types_map,
119
+ ).data
120
+
121
+ @validate_call
122
+ async def monitoring_post_run_probe_with_http_info(
123
+ self,
124
+ tenant: Annotated[
125
+ str,
126
+ Field(
127
+ min_length=36, strict=True, max_length=36, description="The tenant id"
128
+ ),
129
+ ],
130
+ _request_timeout: Union[
131
+ None,
132
+ Annotated[StrictFloat, Field(gt=0)],
133
+ Tuple[
134
+ Annotated[StrictFloat, Field(gt=0)], Annotated[StrictFloat, Field(gt=0)]
135
+ ],
136
+ ] = None,
137
+ _request_auth: Optional[Dict[StrictStr, Any]] = None,
138
+ _content_type: Optional[StrictStr] = None,
139
+ _headers: Optional[Dict[StrictStr, Any]] = None,
140
+ _host_index: Annotated[StrictInt, Field(ge=0, le=0)] = 0,
141
+ ) -> ApiResponse[None]:
142
+ """Detailed Health Probe For the Instance
143
+
144
+ Triggers a workflow to check the status of the instance
145
+
146
+ :param tenant: The tenant id (required)
147
+ :type tenant: str
148
+ :param _request_timeout: timeout setting for this request. If one
149
+ number provided, it will be total request
150
+ timeout. It can also be a pair (tuple) of
151
+ (connection, read) timeouts.
152
+ :type _request_timeout: int, tuple(int, int), optional
153
+ :param _request_auth: set to override the auth_settings for an a single
154
+ request; this effectively ignores the
155
+ authentication in the spec for a single request.
156
+ :type _request_auth: dict, optional
157
+ :param _content_type: force content-type for the request.
158
+ :type _content_type: str, Optional
159
+ :param _headers: set to override the headers for a single
160
+ request; this effectively ignores the headers
161
+ in the spec for a single request.
162
+ :type _headers: dict, optional
163
+ :param _host_index: set to override the host_index for a single
164
+ request; this effectively ignores the host_index
165
+ in the spec for a single request.
166
+ :type _host_index: int, optional
167
+ :return: Returns the result object.
168
+ """ # noqa: E501
169
+
170
+ _param = self._monitoring_post_run_probe_serialize(
171
+ tenant=tenant,
172
+ _request_auth=_request_auth,
173
+ _content_type=_content_type,
174
+ _headers=_headers,
175
+ _host_index=_host_index,
176
+ )
177
+
178
+ _response_types_map: Dict[str, Optional[str]] = {
179
+ "200": None,
180
+ "403": "APIErrors",
181
+ }
182
+ response_data = await self.api_client.call_api(
183
+ *_param, _request_timeout=_request_timeout
184
+ )
185
+ await response_data.read()
186
+ return self.api_client.response_deserialize(
187
+ response_data=response_data,
188
+ response_types_map=_response_types_map,
189
+ )
190
+
191
+ @validate_call
192
+ async def monitoring_post_run_probe_without_preload_content(
193
+ self,
194
+ tenant: Annotated[
195
+ str,
196
+ Field(
197
+ min_length=36, strict=True, max_length=36, description="The tenant id"
198
+ ),
199
+ ],
200
+ _request_timeout: Union[
201
+ None,
202
+ Annotated[StrictFloat, Field(gt=0)],
203
+ Tuple[
204
+ Annotated[StrictFloat, Field(gt=0)], Annotated[StrictFloat, Field(gt=0)]
205
+ ],
206
+ ] = None,
207
+ _request_auth: Optional[Dict[StrictStr, Any]] = None,
208
+ _content_type: Optional[StrictStr] = None,
209
+ _headers: Optional[Dict[StrictStr, Any]] = None,
210
+ _host_index: Annotated[StrictInt, Field(ge=0, le=0)] = 0,
211
+ ) -> RESTResponseType:
212
+ """Detailed Health Probe For the Instance
213
+
214
+ Triggers a workflow to check the status of the instance
215
+
216
+ :param tenant: The tenant id (required)
217
+ :type tenant: str
218
+ :param _request_timeout: timeout setting for this request. If one
219
+ number provided, it will be total request
220
+ timeout. It can also be a pair (tuple) of
221
+ (connection, read) timeouts.
222
+ :type _request_timeout: int, tuple(int, int), optional
223
+ :param _request_auth: set to override the auth_settings for an a single
224
+ request; this effectively ignores the
225
+ authentication in the spec for a single request.
226
+ :type _request_auth: dict, optional
227
+ :param _content_type: force content-type for the request.
228
+ :type _content_type: str, Optional
229
+ :param _headers: set to override the headers for a single
230
+ request; this effectively ignores the headers
231
+ in the spec for a single request.
232
+ :type _headers: dict, optional
233
+ :param _host_index: set to override the host_index for a single
234
+ request; this effectively ignores the host_index
235
+ in the spec for a single request.
236
+ :type _host_index: int, optional
237
+ :return: Returns the result object.
238
+ """ # noqa: E501
239
+
240
+ _param = self._monitoring_post_run_probe_serialize(
241
+ tenant=tenant,
242
+ _request_auth=_request_auth,
243
+ _content_type=_content_type,
244
+ _headers=_headers,
245
+ _host_index=_host_index,
246
+ )
247
+
248
+ _response_types_map: Dict[str, Optional[str]] = {
249
+ "200": None,
250
+ "403": "APIErrors",
251
+ }
252
+ response_data = await self.api_client.call_api(
253
+ *_param, _request_timeout=_request_timeout
254
+ )
255
+ return response_data.response
256
+
257
+ def _monitoring_post_run_probe_serialize(
258
+ self,
259
+ tenant,
260
+ _request_auth,
261
+ _content_type,
262
+ _headers,
263
+ _host_index,
264
+ ) -> RequestSerialized:
265
+
266
+ _host = None
267
+
268
+ _collection_formats: Dict[str, str] = {}
269
+
270
+ _path_params: Dict[str, str] = {}
271
+ _query_params: List[Tuple[str, str]] = []
272
+ _header_params: Dict[str, Optional[str]] = _headers or {}
273
+ _form_params: List[Tuple[str, str]] = []
274
+ _files: Dict[
275
+ str, Union[str, bytes, List[str], List[bytes], List[Tuple[str, bytes]]]
276
+ ] = {}
277
+ _body_params: Optional[bytes] = None
278
+
279
+ # process the path parameters
280
+ if tenant is not None:
281
+ _path_params["tenant"] = tenant
282
+ # process the query parameters
283
+ # process the header parameters
284
+ # process the form parameters
285
+ # process the body parameter
286
+
287
+ # set the HTTP header `Accept`
288
+ if "Accept" not in _header_params:
289
+ _header_params["Accept"] = self.api_client.select_header_accept(
290
+ ["application/json"]
291
+ )
292
+
293
+ # authentication setting
294
+ _auth_settings: List[str] = ["cookieAuth", "bearerAuth"]
295
+
296
+ return self.api_client.param_serialize(
297
+ method="POST",
298
+ resource_path="/api/v1/monitoring/{tenant}/probe",
299
+ path_params=_path_params,
300
+ query_params=_query_params,
301
+ header_params=_header_params,
302
+ body=_body_params,
303
+ post_params=_form_params,
304
+ files=_files,
305
+ auth_settings=_auth_settings,
306
+ collection_formats=_collection_formats,
307
+ _host=_host,
308
+ _request_auth=_request_auth,
309
+ )
310
+
51
311
  @validate_call
52
312
  async def tenant_invite_delete(
53
313
  self,
@@ -13,8 +13,8 @@ from hatchet_sdk.clients.rest_client import RestApi
13
13
  from hatchet_sdk.clients.run_event_listener import RunEventListenerClient
14
14
  from hatchet_sdk.clients.workflow_listener import PooledWorkflowRunListener
15
15
  from hatchet_sdk.context.worker_context import WorkerContext
16
- from hatchet_sdk.contracts.dispatcher_pb2 import OverridesData # type: ignore
17
- from hatchet_sdk.contracts.workflows_pb2 import ( # type: ignore[attr-defined]
16
+ from hatchet_sdk.contracts.dispatcher_pb2 import OverridesData
17
+ from hatchet_sdk.contracts.workflows_pb2 import (
18
18
  BulkTriggerWorkflowRequest,
19
19
  TriggerWorkflowRequest,
20
20
  )
@@ -69,7 +69,7 @@ class BaseContext:
69
69
  meta = options["additional_metadata"]
70
70
 
71
71
  ## TODO: Pydantic here to simplify this
72
- trigger_options: TriggerWorkflowOptions = { # type: ignore[typeddict-item]
72
+ trigger_options: TriggerWorkflowOptions = {
73
73
  "parent_id": workflow_run_id,
74
74
  "parent_step_run_id": step_run_id,
75
75
  "child_key": key,
@@ -149,8 +149,7 @@ class ContextAioImpl(BaseContext):
149
149
  key = child_workflow_run.get("key")
150
150
  options = child_workflow_run.get("options", {})
151
151
 
152
- ## TODO: figure out why this is failing
153
- trigger_options = self._prepare_workflow_options(key, options, worker_id) # type: ignore[arg-type]
152
+ trigger_options = self._prepare_workflow_options(key, options, worker_id)
154
153
 
155
154
  bulk_trigger_workflow_runs.append(
156
155
  WorkflowRunDict(
@@ -238,14 +237,17 @@ class Context(BaseContext):
238
237
  self.input = self.data.get("input", {})
239
238
 
240
239
  def step_output(self, step: str) -> dict[str, Any] | BaseModel:
241
- validators = self.validator_registry.get(step)
240
+ workflow_validator = next(
241
+ (v for k, v in self.validator_registry.items() if k.split(":")[-1] == step),
242
+ None,
243
+ )
242
244
 
243
245
  try:
244
246
  parent_step_data = cast(dict[str, Any], self.data["parents"][step])
245
247
  except KeyError:
246
248
  raise ValueError(f"Step output for '{step}' not found")
247
249
 
248
- if validators and (v := validators.step_output):
250
+ if workflow_validator and (v := workflow_validator.step_output):
249
251
  return v.model_validate(parent_step_data)
250
252
 
251
253
  return parent_step_data
@@ -2,7 +2,7 @@ from hatchet_sdk.clients.dispatcher.dispatcher import DispatcherClient
2
2
 
3
3
 
4
4
  class WorkerContext:
5
- _worker_id: str = None
5
+ _worker_id: str | None = None
6
6
  _registered_workflow_names: list[str] = []
7
7
  _labels: dict[str, str | int] = {}
8
8
 
@@ -10,18 +10,18 @@ class WorkerContext:
10
10
  self._labels = labels
11
11
  self.client = client
12
12
 
13
- def labels(self):
13
+ def labels(self) -> dict[str, str | int]:
14
14
  return self._labels
15
15
 
16
- def upsert_labels(self, labels: dict[str, str | int]):
16
+ def upsert_labels(self, labels: dict[str, str | int]) -> None:
17
17
  self.client.upsert_worker_labels(self._worker_id, labels)
18
18
  self._labels.update(labels)
19
19
 
20
- async def async_upsert_labels(self, labels: dict[str, str | int]):
20
+ async def async_upsert_labels(self, labels: dict[str, str | int]) -> None:
21
21
  await self.client.async_upsert_worker_labels(self._worker_id, labels)
22
22
  self._labels.update(labels)
23
23
 
24
- def id(self) -> str:
24
+ def id(self) -> str | None:
25
25
  return self._worker_id
26
26
 
27
27
  # def has_workflow(self, workflow_name: str):
@@ -1,12 +1,22 @@
1
1
  # -*- coding: utf-8 -*-
2
2
  # Generated by the protocol buffer compiler. DO NOT EDIT!
3
+ # NO CHECKED-IN PROTOBUF GENCODE
3
4
  # source: dispatcher.proto
4
- # Protobuf Python Version: 4.25.1
5
+ # Protobuf Python Version: 5.28.1
5
6
  """Generated protocol buffer code."""
6
7
  from google.protobuf import descriptor as _descriptor
7
8
  from google.protobuf import descriptor_pool as _descriptor_pool
9
+ from google.protobuf import runtime_version as _runtime_version
8
10
  from google.protobuf import symbol_database as _symbol_database
9
11
  from google.protobuf.internal import builder as _builder
12
+ _runtime_version.ValidateProtobufRuntimeVersion(
13
+ _runtime_version.Domain.PUBLIC,
14
+ 5,
15
+ 28,
16
+ 1,
17
+ '',
18
+ 'dispatcher.proto'
19
+ )
10
20
  # @@protoc_insertion_point(imports)
11
21
 
12
22
  _sym_db = _symbol_database.Default()
@@ -20,12 +30,12 @@ DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x10\x64ispatcher.
20
30
  _globals = globals()
21
31
  _builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals)
22
32
  _builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, 'dispatcher_pb2', _globals)
23
- if _descriptor._USE_C_DESCRIPTORS == False:
24
- _globals['DESCRIPTOR']._options = None
33
+ if not _descriptor._USE_C_DESCRIPTORS:
34
+ _globals['DESCRIPTOR']._loaded_options = None
25
35
  _globals['DESCRIPTOR']._serialized_options = b'ZEgithub.com/hatchet-dev/hatchet/internal/services/dispatcher/contracts'
26
- _globals['_WORKERREGISTERREQUEST_LABELSENTRY']._options = None
36
+ _globals['_WORKERREGISTERREQUEST_LABELSENTRY']._loaded_options = None
27
37
  _globals['_WORKERREGISTERREQUEST_LABELSENTRY']._serialized_options = b'8\001'
28
- _globals['_UPSERTWORKERLABELSREQUEST_LABELSENTRY']._options = None
38
+ _globals['_UPSERTWORKERLABELSREQUEST_LABELSENTRY']._loaded_options = None
29
39
  _globals['_UPSERTWORKERLABELSREQUEST_LABELSENTRY']._serialized_options = b'8\001'
30
40
  _globals['_SDKS']._serialized_start=3484
31
41
  _globals['_SDKS']._serialized_end=3539