hatchet-sdk 1.12.3__py3-none-any.whl → 1.13.0__py3-none-any.whl

This diff shows the changes between publicly released versions of the package as they appear in their respective public registries, and is provided for informational purposes only.

Potentially problematic release.

Files changed (48)
  1. hatchet_sdk/__init__.py +46 -40
  2. hatchet_sdk/clients/admin.py +18 -23
  3. hatchet_sdk/clients/dispatcher/action_listener.py +4 -3
  4. hatchet_sdk/clients/dispatcher/dispatcher.py +1 -4
  5. hatchet_sdk/clients/event_ts.py +2 -1
  6. hatchet_sdk/clients/events.py +16 -12
  7. hatchet_sdk/clients/listeners/durable_event_listener.py +4 -2
  8. hatchet_sdk/clients/listeners/pooled_listener.py +2 -2
  9. hatchet_sdk/clients/listeners/run_event_listener.py +7 -8
  10. hatchet_sdk/clients/listeners/workflow_listener.py +14 -6
  11. hatchet_sdk/clients/rest/api_response.py +3 -2
  12. hatchet_sdk/clients/rest/tenacity_utils.py +6 -8
  13. hatchet_sdk/config.py +2 -0
  14. hatchet_sdk/connection.py +10 -4
  15. hatchet_sdk/context/context.py +170 -46
  16. hatchet_sdk/context/worker_context.py +4 -7
  17. hatchet_sdk/contracts/dispatcher_pb2.py +38 -38
  18. hatchet_sdk/contracts/dispatcher_pb2.pyi +4 -2
  19. hatchet_sdk/contracts/events_pb2.py +13 -13
  20. hatchet_sdk/contracts/events_pb2.pyi +4 -2
  21. hatchet_sdk/contracts/v1/workflows_pb2.py +1 -1
  22. hatchet_sdk/contracts/v1/workflows_pb2.pyi +2 -2
  23. hatchet_sdk/exceptions.py +99 -1
  24. hatchet_sdk/features/cron.py +2 -2
  25. hatchet_sdk/features/filters.py +3 -3
  26. hatchet_sdk/features/runs.py +4 -4
  27. hatchet_sdk/features/scheduled.py +8 -9
  28. hatchet_sdk/hatchet.py +65 -64
  29. hatchet_sdk/opentelemetry/instrumentor.py +20 -20
  30. hatchet_sdk/runnables/action.py +1 -2
  31. hatchet_sdk/runnables/contextvars.py +19 -0
  32. hatchet_sdk/runnables/task.py +37 -29
  33. hatchet_sdk/runnables/types.py +9 -8
  34. hatchet_sdk/runnables/workflow.py +57 -42
  35. hatchet_sdk/utils/proto_enums.py +4 -4
  36. hatchet_sdk/utils/timedelta_to_expression.py +2 -3
  37. hatchet_sdk/utils/typing.py +11 -17
  38. hatchet_sdk/waits.py +6 -5
  39. hatchet_sdk/worker/action_listener_process.py +33 -13
  40. hatchet_sdk/worker/runner/run_loop_manager.py +15 -11
  41. hatchet_sdk/worker/runner/runner.py +102 -92
  42. hatchet_sdk/worker/runner/utils/capture_logs.py +72 -31
  43. hatchet_sdk/worker/worker.py +29 -25
  44. hatchet_sdk/workflow_run.py +4 -2
  45. {hatchet_sdk-1.12.3.dist-info → hatchet_sdk-1.13.0.dist-info}/METADATA +1 -1
  46. {hatchet_sdk-1.12.3.dist-info → hatchet_sdk-1.13.0.dist-info}/RECORD +48 -48
  47. {hatchet_sdk-1.12.3.dist-info → hatchet_sdk-1.13.0.dist-info}/WHEEL +0 -0
  48. {hatchet_sdk-1.12.3.dist-info → hatchet_sdk-1.13.0.dist-info}/entry_points.txt +0 -0
hatchet_sdk/__init__.py CHANGED
@@ -1,5 +1,4 @@
 from hatchet_sdk.clients.admin import (
-    DedupeViolationErr,
     ScheduleTriggerWorkflowOptions,
     TriggerWorkflowOptions,
 )
@@ -138,6 +137,11 @@ from hatchet_sdk.contracts.workflows_pb2 import (
     RateLimitDuration,
     WorkerLabelComparator,
 )
+from hatchet_sdk.exceptions import (
+    DedupeViolationError,
+    FailedTaskRunExceptionGroup,
+    TaskRunError,
+)
 from hatchet_sdk.features.runs import BulkCancelReplayOpts, RunFilter
 from hatchet_sdk.hatchet import Hatchet
 from hatchet_sdk.runnables.task import Task
@@ -162,7 +166,6 @@ from hatchet_sdk.waits import (
 from hatchet_sdk.worker.worker import Worker, WorkerStartOptions, WorkerStatus

 __all__ = [
-    "AcceptInviteRequest",
     "APIError",
     "APIErrors",
     "APIMeta",
@@ -170,11 +173,24 @@ __all__ = [
     "APIMetaIntegration",
     "APIResourceMeta",
     "APIToken",
+    "AcceptInviteRequest",
+    "BulkCancelReplayOpts",
+    "ClientConfig",
+    "ClientTLSConfig",
+    "ConcurrencyExpression",
+    "ConcurrencyLimitStrategy",
+    "Condition",
+    "Context",
     "CreateAPITokenRequest",
     "CreateAPITokenResponse",
     "CreatePullRequestFromStepRun",
     "CreateTenantInviteRequest",
     "CreateTenantRequest",
+    "CreateWorkflowVersionOpts",
+    "DedupeViolationError",
+    "DefaultFilter",
+    "DurableContext",
+    "EmptyModel",
     "Event",
     "EventData",
     "EventKeyList",
@@ -182,10 +198,12 @@ __all__ = [
     "EventOrderByDirection",
     "EventOrderByField",
     "EventWorkflowRunSummary",
+    "FailedTaskRunExceptionGroup",
     "GetStepRunDiffResponse",
     "GithubAppInstallation",
     "GithubBranch",
     "GithubRepo",
+    "Hatchet",
     "Job",
     "JobRun",
     "JobRunStatus",
@@ -198,15 +216,30 @@ __all__ = [
     "LogLineList",
     "LogLineOrderByDirection",
     "LogLineOrderByField",
+    "OTelAttribute",
+    "OpenTelemetryConfig",
+    "OrGroup",
     "PaginationResponse",
+    "ParentCondition",
     "PullRequest",
     "PullRequestState",
+    "PushEventOptions",
+    "RateLimitDuration",
+    "RegisterDurableEventRequest",
     "RejectInviteRequest",
     "ReplayEventRequest",
     "RerunStepRunRequest",
+    "RunFilter",
+    "ScheduleTriggerWorkflowOptions",
+    "SleepCondition",
     "StepRun",
     "StepRunDiff",
+    "StepRunEventType",
     "StepRunStatus",
+    "StickyStrategy",
+    "Task",
+    "TaskDefaults",
+    "TaskRunError",
     "Tenant",
     "TenantInvite",
     "TenantInviteList",
@@ -214,20 +247,30 @@ __all__ = [
     "TenantMember",
     "TenantMemberList",
     "TenantMemberRole",
+    "TriggerWorkflowOptions",
     "TriggerWorkflowRunRequest",
     "UpdateTenantInviteRequest",
     "User",
+    "UserEventCondition",
     "UserLoginRequest",
     "UserRegisterRequest",
     "UserTenantMembershipsList",
     "UserTenantPublic",
+    "V1TaskStatus",
+    "Worker",
     "Worker",
+    "WorkerContext",
     "WorkerLabelComparator",
     "WorkerList",
+    "WorkerStartOptions",
+    "WorkerStatus",
+    "Workflow",
     "Workflow",
+    "WorkflowConfig",
     "WorkflowDeploymentConfig",
     "WorkflowList",
     "WorkflowRun",
+    "WorkflowRunEventType",
     "WorkflowRunList",
     "WorkflowRunStatus",
     "WorkflowRunTriggeredBy",
@@ -238,43 +281,6 @@ __all__ = [
     "WorkflowVersion",
     "WorkflowVersionDefinition",
     "WorkflowVersionMeta",
-    "ConcurrencyLimitStrategy",
-    "CreateWorkflowVersionOpts",
-    "RateLimitDuration",
-    "StickyStrategy",
-    "DedupeViolationErr",
-    "ScheduleTriggerWorkflowOptions",
-    "TriggerWorkflowOptions",
-    "PushEventOptions",
-    "StepRunEventType",
-    "WorkflowRunEventType",
-    "Context",
-    "WorkerContext",
-    "ClientConfig",
-    "Hatchet",
-    "workflow",
-    "Worker",
-    "WorkerStartOptions",
-    "WorkerStatus",
-    "ConcurrencyExpression",
-    "Workflow",
-    "WorkflowConfig",
-    "Task",
-    "EmptyModel",
-    "Condition",
-    "OrGroup",
     "or_",
-    "SleepCondition",
-    "UserEventCondition",
-    "ParentCondition",
-    "DurableContext",
-    "RegisterDurableEventRequest",
-    "TaskDefaults",
-    "BulkCancelReplayOpts",
-    "RunFilter",
-    "V1TaskStatus",
-    "OTelAttribute",
-    "OpenTelemetryConfig",
-    "ClientTLSConfig",
-    "DefaultFilter",
+    "workflow",
 ]
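
For consumers of the package root, the substantive change above is the exception move: DedupeViolationErr is no longer importable from hatchet_sdk, while DedupeViolationError, FailedTaskRunExceptionGroup, and TaskRunError now are (the rest of the churn is an alphabetical re-sort of __all__). A minimal migration sketch, using only names visible in this diff:

    # 1.12.x (old spelling, re-exported from hatchet_sdk.clients.admin):
    # from hatchet_sdk import DedupeViolationErr

    # 1.13.0 (new spelling, defined in hatchet_sdk.exceptions):
    from hatchet_sdk import (
        DedupeViolationError,
        FailedTaskRunExceptionGroup,
        TaskRunError,
    )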
hatchet_sdk/clients/admin.py CHANGED
@@ -1,7 +1,8 @@
 import asyncio
 import json
+from collections.abc import Generator
 from datetime import datetime
-from typing import Generator, TypeVar, Union, cast
+from typing import TypeVar, cast

 import grpc
 from google.protobuf import timestamp_pb2
@@ -16,6 +17,7 @@ from hatchet_sdk.contracts import workflows_pb2 as v0_workflow_protos
 from hatchet_sdk.contracts.v1 import workflows_pb2 as workflow_protos
 from hatchet_sdk.contracts.v1.workflows_pb2_grpc import AdminServiceStub
 from hatchet_sdk.contracts.workflows_pb2_grpc import WorkflowServiceStub
+from hatchet_sdk.exceptions import DedupeViolationError
 from hatchet_sdk.features.runs import RunsClient
 from hatchet_sdk.metadata import get_metadata
 from hatchet_sdk.rate_limit import RateLimitDuration
@@ -59,12 +61,6 @@ class WorkflowRunTriggerConfig(BaseModel):
     key: str | None = None


-class DedupeViolationErr(Exception):
-    """Raised by the Hatchet library to indicate that a workflow has already been run with this deduplication value."""
-
-    pass
-
-
 class AdminClient:
     def __init__(
         self,
@@ -113,7 +109,7 @@ class AdminClient:
         try:
             return json.dumps(v).encode("utf-8")
         except json.JSONDecodeError as e:
-            raise ValueError(f"Error encoding payload: {e}")
+            raise ValueError("Error encoding payload") from e

     def _prepare_workflow_request(
         self,
@@ -124,7 +120,7 @@ class AdminClient:
         try:
             payload_data = json.dumps(input)
         except json.JSONDecodeError as e:
-            raise ValueError(f"Error encoding payload: {e}")
+            raise ValueError("Error encoding payload") from e

         _options = self.TriggerWorkflowRequest.model_validate(options.model_dump())
@@ -148,18 +144,17 @@ class AdminClient:
             seconds = int(t)
             nanos = int(t % 1 * 1e9)
             return timestamp_pb2.Timestamp(seconds=seconds, nanos=nanos)
-        elif isinstance(schedule, timestamp_pb2.Timestamp):
+        if isinstance(schedule, timestamp_pb2.Timestamp):
             return schedule
-        else:
-            raise ValueError(
-                "Invalid schedule type. Must be datetime or timestamp_pb2.Timestamp."
-            )
+        raise ValueError(
+            "Invalid schedule type. Must be datetime or timestamp_pb2.Timestamp."
+        )

     def _prepare_schedule_workflow_request(
         self,
         name: str,
-        schedules: list[Union[datetime, timestamp_pb2.Timestamp]],
-        input: JSONSerializableMapping = {},
+        schedules: list[datetime | timestamp_pb2.Timestamp],
+        input: JSONSerializableMapping | None = None,
         options: ScheduleTriggerWorkflowOptions = ScheduleTriggerWorkflowOptions(),
     ) -> v0_workflow_protos.ScheduleWorkflowRequest:
         return v0_workflow_protos.ScheduleWorkflowRequest(
@@ -194,8 +189,8 @@ class AdminClient:
     async def aio_schedule_workflow(
         self,
         name: str,
-        schedules: list[Union[datetime, timestamp_pb2.Timestamp]],
-        input: JSONSerializableMapping = {},
+        schedules: list[datetime | timestamp_pb2.Timestamp],
+        input: JSONSerializableMapping | None = None,
         options: ScheduleTriggerWorkflowOptions = ScheduleTriggerWorkflowOptions(),
     ) -> v0_workflow_protos.WorkflowVersion:
         return await asyncio.to_thread(
@@ -245,8 +240,8 @@ class AdminClient:
     def schedule_workflow(
         self,
         name: str,
-        schedules: list[Union[datetime, timestamp_pb2.Timestamp]],
-        input: JSONSerializableMapping = {},
+        schedules: list[datetime | timestamp_pb2.Timestamp],
+        input: JSONSerializableMapping | None = None,
         options: ScheduleTriggerWorkflowOptions = ScheduleTriggerWorkflowOptions(),
     ) -> v0_workflow_protos.WorkflowVersion:
         try:
@@ -269,7 +264,7 @@ class AdminClient:
             )
         except (grpc.RpcError, grpc.aio.AioRpcError) as e:
             if e.code() == grpc.StatusCode.ALREADY_EXISTS:
-                raise DedupeViolationErr(e.details())
+                raise DedupeViolationError(e.details()) from e

             raise e
@@ -336,7 +331,7 @@ class AdminClient:
             )
         except (grpc.RpcError, grpc.aio.AioRpcError) as e:
             if e.code() == grpc.StatusCode.ALREADY_EXISTS:
-                raise DedupeViolationErr(e.details())
+                raise DedupeViolationError(e.details()) from e
             raise e

         return WorkflowRunRef(
@@ -369,7 +364,7 @@ class AdminClient:
             )
         except (grpc.RpcError, grpc.aio.AioRpcError) as e:
             if e.code() == grpc.StatusCode.ALREADY_EXISTS:
-                raise DedupeViolationErr(e.details())
+                raise DedupeViolationError(e.details()) from e

             raise e
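
Handlers written against 1.12.x stop matching once the ALREADY_EXISTS branch raises the renamed exception. A sketch of an updated call site against schedule_workflow's new signature; the client instance and workflow name are placeholders, not part of this diff:

    from datetime import datetime, timedelta

    from hatchet_sdk import DedupeViolationError

    try:
        # `admin` stands in for an AdminClient owned by user code
        admin.schedule_workflow(
            name="my-workflow",
            schedules=[datetime.now() + timedelta(minutes=5)],
            input=None,  # the mutable `{}` default was replaced by None
        )
    except DedupeViolationError as e:
        # raised when gRPC reports ALREADY_EXISTS for this dedupe value
        print(f"duplicate workflow run: {e}")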
hatchet_sdk/clients/dispatcher/action_listener.py CHANGED
@@ -1,7 +1,8 @@
 import asyncio
 import json
 import time
-from typing import TYPE_CHECKING, AsyncGenerator, cast
+from collections.abc import AsyncGenerator
+from typing import TYPE_CHECKING, cast

 import grpc
 import grpc.aio
@@ -302,7 +303,7 @@ class ActionListener:
             )
             self.run_heartbeat = False
             raise Exception("retry_exhausted")
-        elif self.retries >= 1:
+        if self.retries >= 1:
             # logger.info
             # if we are retrying, we wait for a bit. this should eventually be replaced with exp backoff + jitter
             await exp_backoff_sleep(
@@ -369,4 +370,4 @@ class ActionListener:

             return cast(WorkerUnsubscribeRequest, req)
         except grpc.RpcError as e:
-            raise Exception(f"Failed to unsubscribe: {e}")
+            raise Exception("Failed to unsubscribe") from e
hatchet_sdk/clients/dispatcher/dispatcher.py CHANGED
@@ -93,10 +93,7 @@ class DispatcherClient:
             )
         except Exception as e:
             # for step action events, send a failure event when we cannot send the completed event
-            if (
-                event_type == STEP_EVENT_TYPE_COMPLETED
-                or event_type == STEP_EVENT_TYPE_FAILED
-            ):
+            if event_type in (STEP_EVENT_TYPE_COMPLETED, STEP_EVENT_TYPE_FAILED):
                 await self._try_send_step_action_event(
                     action,
                     STEP_EVENT_TYPE_FAILED,
hatchet_sdk/clients/event_ts.py CHANGED
@@ -1,5 +1,6 @@
 import asyncio
-from typing import Callable, Generic, TypeVar, cast, overload
+from collections.abc import Callable
+from typing import Generic, TypeVar, cast, overload

 import grpc.aio
 from grpc._cython import cygrpc  # type: ignore[attr-defined]
hatchet_sdk/clients/events.py CHANGED
@@ -1,7 +1,7 @@
 import asyncio
 import datetime
 import json
-from typing import List, cast
+from typing import cast

 from google.protobuf import timestamp_pb2
 from pydantic import BaseModel, Field
@@ -88,7 +88,7 @@ class EventClient(BaseRestClient):
         self,
         events: list[BulkPushEventWithMetadata],
         options: BulkPushEventOptions = BulkPushEventOptions(),
-    ) -> List[Event]:
+    ) -> list[Event]:
         return await asyncio.to_thread(self.bulk_push, events=events, options=options)

     ## IMPORTANT: Keep this method's signature in sync with the wrapper in the OTel instrumentor
@@ -105,12 +105,12 @@ class EventClient(BaseRestClient):
         try:
             meta_bytes = json.dumps(options.additional_metadata)
         except Exception as e:
-            raise ValueError(f"Error encoding meta: {e}")
+            raise ValueError("Error encoding meta") from e

         try:
             payload_str = json.dumps(payload)
         except (TypeError, ValueError) as e:
-            raise ValueError(f"Error encoding payload: {e}")
+            raise ValueError("Error encoding payload") from e

         request = PushEventRequest(
             key=namespaced_event_key,
@@ -139,12 +139,12 @@ class EventClient(BaseRestClient):
         try:
             meta_str = json.dumps(meta)
         except Exception as e:
-            raise ValueError(f"Error encoding meta: {e}")
+            raise ValueError("Error encoding meta") from e

         try:
             serialized_payload = json.dumps(payload)
         except (TypeError, ValueError) as e:
-            raise ValueError(f"Error serializing payload: {e}")
+            raise ValueError("Error serializing payload") from e

         return PushEventRequest(
             key=event_key,
@@ -159,9 +159,9 @@ class EventClient(BaseRestClient):
     @tenacity_retry
     def bulk_push(
         self,
-        events: List[BulkPushEventWithMetadata],
+        events: list[BulkPushEventWithMetadata],
         options: BulkPushEventOptions = BulkPushEventOptions(),
-    ) -> List[Event]:
+    ) -> list[Event]:
         namespace = options.namespace or self.namespace

         bulk_request = BulkPushEventRequest(
@@ -190,7 +190,7 @@ class EventClient(BaseRestClient):
         self.events_service_client.PutLog(request, metadata=get_metadata(self.token))

     @tenacity_retry
-    def stream(self, data: str | bytes, step_run_id: str) -> None:
+    def stream(self, data: str | bytes, step_run_id: str, index: int) -> None:
         if isinstance(data, str):
             data_bytes = data.encode("utf-8")
         elif isinstance(data, bytes):
@@ -202,11 +202,15 @@ class EventClient(BaseRestClient):
             stepRunId=step_run_id,
             createdAt=proto_timestamp_now(),
             message=data_bytes,
+            eventIndex=index,
         )

-        self.events_service_client.PutStreamEvent(
-            request, metadata=get_metadata(self.token)
-        )
+        try:
+            self.events_service_client.PutStreamEvent(
+                request, metadata=get_metadata(self.token)
+            )
+        except Exception:
+            raise

     async def aio_list(
         self,
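
The breaking piece in this file is EventClient.stream, which now takes a required index that is forwarded as eventIndex on the proto request, so callers must number their chunks. A sketch with a placeholder client and step run ID:

    # `client` stands in for an EventClient; the step run ID would come from the running task
    for i, chunk in enumerate(["chunk-1", "chunk-2"]):
        client.stream(chunk, step_run_id="<step-run-id>", index=i)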
hatchet_sdk/clients/listeners/durable_event_listener.py CHANGED
@@ -8,6 +8,7 @@ from pydantic import BaseModel, ConfigDict

 from hatchet_sdk.clients.listeners.pooled_listener import PooledListener
 from hatchet_sdk.clients.rest.tenacity_utils import tenacity_retry
+from hatchet_sdk.config import ClientConfig
 from hatchet_sdk.connection import new_conn
 from hatchet_sdk.contracts.v1.dispatcher_pb2 import (
     DurableEvent,
@@ -32,6 +33,7 @@ class RegisterDurableEventRequest(BaseModel):
     task_id: str
     signal_key: str
     conditions: list[SleepCondition | UserEventCondition]
+    config: ClientConfig

     def to_proto(self) -> RegisterDurableEventRequestProto:
         return RegisterDurableEventRequestProto(
@@ -39,12 +41,12 @@ class RegisterDurableEventRequest(BaseModel):
             signal_key=self.signal_key,
             conditions=DurableEventListenerConditions(
                 sleep_conditions=[
-                    c.to_proto()
+                    c.to_proto(self.config)
                     for c in self.conditions
                     if isinstance(c, SleepCondition)
                 ],
                 user_event_conditions=[
-                    c.to_proto()
+                    c.to_proto(self.config)
                     for c in self.conditions
                     if isinstance(c, UserEventCondition)
                 ],
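
RegisterDurableEventRequest now carries the ClientConfig and threads it into each condition's to_proto call. A construction sketch; the SleepCondition keyword is assumed from the SDK's waits API rather than shown in this diff, and the IDs are placeholders:

    from datetime import timedelta

    from hatchet_sdk import RegisterDurableEventRequest, SleepCondition

    request = RegisterDurableEventRequest(
        task_id="<task-id>",
        signal_key="my-signal",
        conditions=[SleepCondition(duration=timedelta(seconds=30))],  # assumed signature
        config=config,  # the new required field: a ClientConfig instance
    )
    proto = request.to_proto()  # conditions now serialize with access to the config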
hatchet_sdk/clients/listeners/pooled_listener.py CHANGED
@@ -252,10 +252,10 @@ class PooledListener(Generic[R, T, L], ABC):
                     metadata=get_metadata(self.token),
                 )

-            except grpc.RpcError as e:
+            except grpc.RpcError as e:  # noqa: PERF203
                 if e.code() == grpc.StatusCode.UNAVAILABLE:
                     retries = retries + 1
                 else:
-                    raise ValueError(f"gRPC error: {e}")
+                    raise ValueError("gRPC error") from e

         raise ValueError("Failed to connect to listener")
hatchet_sdk/clients/listeners/run_event_listener.py CHANGED
@@ -1,8 +1,9 @@
 import asyncio
+from collections.abc import AsyncGenerator, Callable, Generator
 from enum import Enum
 from queue import Empty, Queue
 from threading import Thread
-from typing import Any, AsyncGenerator, Callable, Generator, Literal, TypeVar, cast
+from typing import Any, Literal, TypeVar, cast

 import grpc
 from pydantic import BaseModel
@@ -129,8 +130,7 @@ class RunEventListener:
         thread.join()

     def __iter__(self) -> Generator[StepRunEvent, None, None]:
-        for item in self.async_to_sync_thread(self.__aiter__()):
-            yield item
+        yield from self.async_to_sync_thread(self.__aiter__())

     async def _generator(self) -> AsyncGenerator[StepRunEvent, None]:
         while True:
@@ -216,7 +216,7 @@ class RunEventListener:
                         metadata=get_metadata(self.config.token),
                     ),
                 )
-            elif self.additional_meta_kv is not None:
+            if self.additional_meta_kv is not None:
                 return cast(
                     AsyncGenerator[WorkflowEvent, None],
                     self.client.SubscribeToWorkflowEvents(
@@ -227,14 +227,13 @@ class RunEventListener:
                         metadata=get_metadata(self.config.token),
                     ),
                 )
-            else:
-                raise Exception("no listener method provided")
+            raise Exception("no listener method provided")

-            except grpc.RpcError as e:
+            except grpc.RpcError as e:  # noqa: PERF203
                 if e.code() == grpc.StatusCode.UNAVAILABLE:
                     retries = retries + 1
                 else:
-                    raise ValueError(f"gRPC error: {e}")
+                    raise ValueError("gRPC error") from e

         raise Exception("Failed to subscribe to workflow events")
hatchet_sdk/clients/listeners/workflow_listener.py CHANGED
@@ -1,5 +1,6 @@
 import json
-from typing import Any, AsyncIterator, cast
+from collections.abc import AsyncIterator
+from typing import Any, cast

 import grpc
 import grpc.aio
@@ -11,6 +12,11 @@ from hatchet_sdk.contracts.dispatcher_pb2 import (
     WorkflowRunEvent,
 )
 from hatchet_sdk.contracts.dispatcher_pb2_grpc import DispatcherStub
+from hatchet_sdk.exceptions import (
+    DedupeViolationError,
+    FailedTaskRunExceptionGroup,
+    TaskRunError,
+)

 DEDUPE_MESSAGE = "DUPLICATE_WORKFLOW_RUN"

@@ -27,16 +33,18 @@ class PooledWorkflowRunListener(
         return response.workflowRunId

     async def aio_result(self, id: str) -> dict[str, Any]:
-        from hatchet_sdk.clients.admin import DedupeViolationErr
-
         event = await self.subscribe(id)
         errors = [result.error for result in event.results if result.error]
+        workflow_run_id = event.workflowRunId

         if errors:
             if DEDUPE_MESSAGE in errors[0]:
-                raise DedupeViolationErr(errors[0])
-            else:
-                raise Exception(f"Workflow Errors: {errors}")
+                raise DedupeViolationError(errors[0])
+
+            raise FailedTaskRunExceptionGroup(
+                f"Workflow run {workflow_run_id} failed.",
+                [TaskRunError.deserialize(e) for e in errors],
+            )

         return {
             result.stepReadableId: json.loads(result.output)
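
aio_result now surfaces per-task failures as a FailedTaskRunExceptionGroup of deserialized TaskRunError values instead of a single flattened Exception string. A handling sketch, assuming the class follows Python's ExceptionGroup convention (suggested by its message-plus-list constructor here):

    from hatchet_sdk import FailedTaskRunExceptionGroup

    try:
        result = await listener.aio_result(workflow_run_id)  # placeholders
    except FailedTaskRunExceptionGroup as eg:
        for task_error in eg.exceptions:  # one TaskRunError per failed task
            print(task_error)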
hatchet_sdk/clients/rest/api_response.py CHANGED
@@ -2,7 +2,8 @@

 from __future__ import annotations

-from typing import Generic, Mapping, Optional, TypeVar
+from collections.abc import Mapping
+from typing import Generic, TypeVar

 from pydantic import BaseModel, Field, StrictBytes, StrictInt

@@ -15,7 +16,7 @@ class ApiResponse(BaseModel, Generic[T]):
     """

     status_code: StrictInt = Field(description="HTTP status code")
-    headers: Optional[Mapping[str, str]] = Field(None, description="HTTP headers")
+    headers: Mapping[str, str] | None = Field(None, description="HTTP headers")
     data: T = Field(description="Deserialized data given the data type")
     raw_data: StrictBytes = Field(description="Raw data (HTTP response body)")
hatchet_sdk/clients/rest/tenacity_utils.py CHANGED
@@ -1,4 +1,5 @@
-from typing import Callable, ParamSpec, TypeVar
+from collections.abc import Callable
+from typing import ParamSpec, TypeVar

 import grpc
 import tenacity
@@ -28,12 +29,9 @@ def tenacity_alert_retry(retry_state: tenacity.RetryCallState) -> None:


 def tenacity_should_retry(ex: BaseException) -> bool:
-    if isinstance(ex, (grpc.aio.AioRpcError, grpc.RpcError)):
-        if ex.code() in [
+    if isinstance(ex, grpc.aio.AioRpcError | grpc.RpcError):
+        return ex.code() not in [
             grpc.StatusCode.UNIMPLEMENTED,
             grpc.StatusCode.NOT_FOUND,
-        ]:
-            return False
-        return True
-    else:
-        return False
+        ]
+    return False
hatchet_sdk/config.py CHANGED
@@ -82,6 +82,8 @@ class ClientConfig(BaseSettings):
     enable_force_kill_sync_threads: bool = False
     enable_thread_pool_monitoring: bool = False

+    terminate_worker_after_num_tasks: int | None = None
+
     @model_validator(mode="after")
     def validate_token_and_tenant(self) -> "ClientConfig":
         if not self.token:
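
The only change here is the new terminate_worker_after_num_tasks knob, which defaults to None (disabled); enforcement lives in the worker modules listed above, not in this file. Setting it is a one-liner; note that ClientConfig is a pydantic BaseSettings model, so token and tenant still resolve from the environment as usual:

    from hatchet_sdk import ClientConfig

    # recycle the worker after a bounded number of tasks (assumed intent of the field name)
    config = ClientConfig(terminate_worker_after_num_tasks=500)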
hatchet_sdk/connection.py CHANGED
@@ -22,7 +22,8 @@ def new_conn(config: ClientConfig, aio: bool) -> grpc.Channel | grpc.aio.Channel
        root: bytes | None = None

        if config.tls_config.root_ca_file:
-            root = open(config.tls_config.root_ca_file, "rb").read()
+            with open(config.tls_config.root_ca_file, "rb") as f:
+                root = f.read()

        credentials = grpc.ssl_channel_credentials(root_certificates=root)
    elif config.tls_config.strategy == "mtls":
@@ -30,9 +31,14 @@ def new_conn(config: ClientConfig, aio: bool) -> grpc.Channel | grpc.aio.Channel
        assert config.tls_config.key_file
        assert config.tls_config.cert_file

-        root = open(config.tls_config.root_ca_file, "rb").read()
-        private_key = open(config.tls_config.key_file, "rb").read()
-        certificate_chain = open(config.tls_config.cert_file, "rb").read()
+        with open(config.tls_config.root_ca_file, "rb") as f:
+            root = f.read()
+
+        with open(config.tls_config.key_file, "rb") as f:
+            private_key = f.read()
+
+        with open(config.tls_config.cert_file, "rb") as f:
+            certificate_chain = f.read()

        credentials = grpc.ssl_channel_credentials(
            root_certificates=root,