hatchet-sdk 1.6.2__py3-none-any.whl → 1.6.4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release. This version of hatchet-sdk might be problematic.

@@ -9,7 +9,11 @@ import grpc
 import grpc.aio
 from pydantic import BaseModel, ConfigDict, Field, field_validator, model_validator
 
-from hatchet_sdk.clients.event_ts import ThreadSafeEvent, read_with_interrupt
+from hatchet_sdk.clients.event_ts import (
+    ThreadSafeEvent,
+    UnexpectedEOF,
+    read_with_interrupt,
+)
 from hatchet_sdk.clients.events import proto_timestamp_now
 from hatchet_sdk.clients.listeners.run_event_listener import (
     DEFAULT_ACTION_LISTENER_RETRY_INTERVAL,
@@ -84,6 +88,9 @@ class ActionType(str, Enum):
     START_GET_GROUP_KEY = "START_GET_GROUP_KEY"
 
 
+ActionKey = str
+
+
 class Action(BaseModel):
     worker_id: str
     tenant_id: str
@@ -137,6 +144,18 @@ class Action(BaseModel):
 
         return {k: v for k, v in attrs.items() if v}
 
+    @property
+    def key(self) -> ActionKey:
+        """
+        This key is used to uniquely identify a single step run by its id + retry count.
+        It's used when storing references to a task, a context, etc. in a dictionary so that
+        we can look up those items in the dictionary by a unique key.
+        """
+        if self.action_type == ActionType.START_GET_GROUP_KEY:
+            return f"{self.get_group_key_run_id}/{self.retry_count}"
+        else:
+            return f"{self.step_run_id}/{self.retry_count}"
+
 
 def parse_additional_metadata(additional_metadata: str) -> JSONSerializableMapping:
     try:
@@ -275,15 +294,17 @@ class ActionListener:
 
                 break
 
-            assigned_action, _, is_eof = t.result()
+            result = t.result()
 
-            if is_eof:
+            if isinstance(result, UnexpectedEOF):
                 logger.debug("Handling EOF in Action Listener")
                 self.retries = self.retries + 1
                 break
 
             self.retries = 0
 
+            assigned_action = result.data
+
             try:
                 action_payload = (
                     ActionPayload()
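
The notable addition in the hunks above is the Action.key property, which identifies a single step run attempt by run id plus retry count. A minimal illustrative sketch of the key format and how such a key can serve as a dictionary index; the id value below is invented, not taken from the package:

    # Illustrative only: the key format Action.key now produces.
    step_run_id = "step-run-1234"  # hypothetical id
    retry_count = 1

    key = f"{step_run_id}/{retry_count}"  # -> "step-run-1234/1"; each retry gets its own key

    # e.g. a per-attempt lookup table keyed by ActionKey (a str alias)
    contexts: dict[str, object] = {}
    contexts[key] = object()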
@@ -1,5 +1,5 @@
 import asyncio
-from typing import Callable, TypeVar, cast, overload
+from typing import Callable, Generic, TypeVar, cast, overload
 
 import grpc.aio
 from grpc._cython import cygrpc  # type: ignore[attr-defined]
@@ -29,12 +29,23 @@ TRequest = TypeVar("TRequest")
 TResponse = TypeVar("TResponse")
 
 
+class ReadWithInterruptResult(Generic[TResponse]):
+    def __init__(self, data: TResponse, key: str):
+        self.data = data
+        self.key = key
+
+
+class UnexpectedEOF:
+    def __init__(self) -> None:
+        pass
+
+
 @overload
 async def read_with_interrupt(
     listener: grpc.aio.UnaryStreamCall[TRequest, TResponse],
     interrupt: ThreadSafeEvent,
     key_generator: Callable[[TResponse], str],
-) -> tuple[TResponse, str, bool]: ...
+) -> ReadWithInterruptResult[TResponse] | UnexpectedEOF: ...
 
 
 @overload
@@ -42,23 +53,23 @@ async def read_with_interrupt(
     listener: grpc.aio.UnaryStreamCall[TRequest, TResponse],
     interrupt: ThreadSafeEvent,
     key_generator: None = None,
-) -> tuple[TResponse, None, bool]: ...
+) -> ReadWithInterruptResult[TResponse] | UnexpectedEOF: ...
 
 
 async def read_with_interrupt(
     listener: grpc.aio.UnaryStreamCall[TRequest, TResponse],
     interrupt: ThreadSafeEvent,
     key_generator: Callable[[TResponse], str] | None = None,
-) -> tuple[TResponse, str | None, bool]:
+) -> ReadWithInterruptResult[TResponse] | UnexpectedEOF:
     try:
         result = cast(TResponse, await listener.read())
 
         if result is cygrpc.EOF:
             logger.warning("Received EOF from engine")
-            return cast(TResponse, None), None, True
+            return UnexpectedEOF()
 
-        key = key_generator(result) if key_generator else None
+        key = key_generator(result) if key_generator else ""
 
-        return result, key, False
+        return ReadWithInterruptResult(data=result, key=key)
     finally:
         interrupt.set()
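
With this change, read_with_interrupt returns either a ReadWithInterruptResult wrapper or an UnexpectedEOF sentinel instead of a tuple, so callers narrow the result with isinstance, as the listener hunks elsewhere in this diff do. A minimal caller-side sketch under that assumption; the listener, interrupt event, and key generator are assumed to be set up elsewhere:

    from hatchet_sdk.clients.event_ts import UnexpectedEOF, read_with_interrupt

    async def read_one(listener, interrupt, key_generator):
        # Sketch of the new caller-side pattern; mirrors the listener changes in this diff.
        result = await read_with_interrupt(listener, interrupt, key_generator)

        if isinstance(result, UnexpectedEOF):
            # Stream ended unexpectedly: callers bump their retry count and resubscribe.
            return None

        # result.key is "" when no key_generator was passed.
        return result.data, result.key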
@@ -149,6 +149,7 @@ class EventClient:
             ).events
         )
 
+    @tenacity_retry
     def log(self, message: str, step_run_id: str) -> None:
         request = PutLogRequest(
             stepRunId=step_run_id,
@@ -158,6 +159,7 @@
 
         self.client.PutLog(request, metadata=get_metadata(self.token))
 
+    @tenacity_retry
    def stream(self, data: str | bytes, step_run_id: str) -> None:
         if isinstance(data, str):
             data_bytes = data.encode("utf-8")
@@ -6,7 +6,11 @@ from typing import Generic, Literal, TypeVar
 import grpc
 import grpc.aio
 
-from hatchet_sdk.clients.event_ts import ThreadSafeEvent, read_with_interrupt
+from hatchet_sdk.clients.event_ts import (
+    ThreadSafeEvent,
+    UnexpectedEOF,
+    read_with_interrupt,
+)
 from hatchet_sdk.config import ClientConfig
 from hatchet_sdk.logger import logger
 from hatchet_sdk.metadata import get_metadata
@@ -130,18 +134,18 @@ class PooledListener(Generic[R, T, L], ABC):
                     await asyncio.sleep(DEFAULT_LISTENER_RETRY_INTERVAL)
                     break
 
-                event, key, is_eof = t.result()
+                event = t.result()
 
-                if is_eof:
+                if isinstance(event, UnexpectedEOF):
                     logger.debug(
                         f"Handling EOF in Pooled Listener {self.__class__.__name__}"
                     )
                     break
 
-                subscriptions = self.to_subscriptions.get(key, [])
+                subscriptions = self.to_subscriptions.get(event.key, [])
 
                 for subscription_id in subscriptions:
-                    await self.events[subscription_id].put(event)
+                    await self.events[subscription_id].put(event.data)
 
         except grpc.RpcError as e:
             logger.debug(f"grpc error in listener: {e}")
@@ -127,15 +127,18 @@ class Context:
     def workflow_run_id(self) -> str:
         return self.action.workflow_run_id
 
+    def _set_cancellation_flag(self) -> None:
+        self.exit_flag = True
+
     def cancel(self) -> None:
         logger.debug("cancelling step...")
         self.runs_client.cancel(self.step_run_id)
-        self.exit_flag = True
+        self._set_cancellation_flag()
 
     async def aio_cancel(self) -> None:
         logger.debug("cancelling step...")
         await self.runs_client.aio_cancel(self.step_run_id)
-        self.exit_flag = True
+        self._set_cancellation_flag()
 
     # done returns true if the context has been cancelled
     def done(self) -> bool:
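
cancel() and aio_cancel() now set the exit flag through the shared _set_cancellation_flag() helper, and done() still reports that flag. A minimal sketch of a task body that polls for cancellation; the ctx argument and the per-chunk work are assumptions standing in for real task code:

    import asyncio

    async def long_running_step(ctx) -> dict:
        # "ctx" is assumed to be the hatchet_sdk Context passed to the task.
        for i in range(100):
            if ctx.done():  # True once the exit flag has been set by cancel()/aio_cancel()
                return {"stopped_at": i}
            await asyncio.sleep(0.1)  # stand-in for a unit of real work
        return {"stopped_at": 100}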
@@ -26,11 +26,6 @@ from hatchet_sdk.utils.typing import JSONSerializableMapping
 class CreateCronTriggerConfig(BaseModel):
     """
     Schema for creating a workflow run triggered by a cron.
-
-    Attributes:
-        expression (str): The cron expression defining the schedule.
-        input (dict): The input data for the cron workflow.
-        additional_metadata (dict[str, str]): Additional metadata associated with the cron trigger (e.g. {"key1": "value1", "key2": "value2"}).
     """
 
     expression: str
@@ -43,14 +38,11 @@ class CreateCronTriggerConfig(BaseModel):
         """
         Validates the cron expression to ensure it adheres to the expected format.
 
-        Args:
-            v (str): The cron expression to validate.
+        :param v: The cron expression to validate.
 
-        Raises:
-            ValueError: If the expression is invalid.
+        :raises ValueError: If the expression is invalid
 
-        Returns:
-            str: The validated cron expression.
+        :return: The validated cron expression.
         """
         if not v:
             raise ValueError("Cron expression is required")
@@ -72,6 +64,10 @@ class CreateCronTriggerConfig(BaseModel):
 
 
 class CronClient(BaseRestClient):
+    """
+    The cron client is a client for managing cron workflows within Hatchet.
+    """
+
     def _wra(self, client: ApiClient) -> WorkflowRunApi:
         return WorkflowRunApi(client)
 
@@ -88,17 +84,16 @@ class CronClient(BaseRestClient):
         priority: int | None = None,
     ) -> CronWorkflows:
         """
-        Asynchronously creates a new workflow cron trigger.
+        Create a new workflow cron trigger.
 
-        Args:
-            workflow_name (str): The name of the workflow to trigger.
-            cron_name (str): The name of the cron trigger.
-            expression (str): The cron expression defining the schedule.
-            input (dict): The input data for the cron workflow.
-            additional_metadata (dict[str, str]): Additional metadata associated with the cron trigger (e.g. {"key1": "value1", "key2": "value2"}).
+        :param workflow_name: The name of the workflow to trigger.
+        :param cron_name: The name of the cron trigger.
+        :param expression: The cron expression defining the schedule.
+        :param input: The input data for the cron workflow.
+        :param additional_metadata: Additional metadata associated with the cron trigger.
+        :param priority: The priority of the cron workflow trigger.
 
-        Returns:
-            CronWorkflows: The created cron workflow instance.
+        :return: The created cron workflow instance.
         """
         validated_input = CreateCronTriggerConfig(
             expression=expression, input=input, additional_metadata=additional_metadata
@@ -126,6 +121,18 @@ class CronClient(BaseRestClient):
         additional_metadata: JSONSerializableMapping,
         priority: int | None = None,
     ) -> CronWorkflows:
+        """
+        Create a new workflow cron trigger.
+
+        :param workflow_name: The name of the workflow to trigger.
+        :param cron_name: The name of the cron trigger.
+        :param expression: The cron expression defining the schedule.
+        :param input: The input data for the cron workflow.
+        :param additional_metadata: Additional metadata associated with the cron trigger.
+        :param priority: The priority of the cron workflow trigger.
+
+        :return: The created cron workflow instance.
+        """
         return await asyncio.to_thread(
             self.create,
             workflow_name,
@@ -138,10 +145,10 @@
 
     def delete(self, cron_id: str) -> None:
         """
-        Asynchronously deletes a workflow cron trigger.
+        Delete a workflow cron trigger.
 
-        Args:
-            cron_id (str): The cron trigger ID or CronWorkflows instance to delete.
+        :param cron_id: The ID of the cron trigger to delete.
+        :return: None
         """
         with self.client() as client:
             self._wa(client).workflow_cron_delete(
@@ -149,6 +156,12 @@
             )
 
     async def aio_delete(self, cron_id: str) -> None:
+        """
+        Delete a workflow cron trigger.
+
+        :param cron_id: The ID of the cron trigger to delete.
+        :return: None
+        """
         return await asyncio.to_thread(self.delete, cron_id)
 
     async def aio_list(
@@ -161,18 +174,16 @@
         order_by_direction: WorkflowRunOrderByDirection | None = None,
     ) -> CronWorkflowsList:
         """
-        Synchronously retrieves a list of all workflow cron triggers matching the criteria.
+        Retrieve a list of all workflow cron triggers matching the criteria.
 
-        Args:
-            offset (int | None): The offset to start the list from.
-            limit (int | None): The maximum number of items to return.
-            workflow_id (str | None): The ID of the workflow to filter by.
-            additional_metadata (list[str] | None): Filter by additional metadata keys (e.g. ["key1:value1", "key2:value2"]).
-            order_by_field (CronWorkflowsOrderByField | None): The field to order the list by.
-            order_by_direction (WorkflowRunOrderByDirection | None): The direction to order the list by.
+        :param offset: The offset to start the list from.
+        :param limit: The maximum number of items to return.
+        :param workflow_id: The ID of the workflow to filter by.
+        :param additional_metadata: Filter by additional metadata keys.
+        :param order_by_field: The field to order the list by.
+        :param order_by_direction: The direction to order the list by.
 
-        Returns:
-            CronWorkflowsList: A list of cron workflows.
+        :return: A list of cron workflows.
         """
         return await asyncio.to_thread(
             self.list,
@@ -194,18 +205,16 @@ class CronClient(BaseRestClient):
         order_by_direction: WorkflowRunOrderByDirection | None = None,
     ) -> CronWorkflowsList:
         """
-        Asynchronously retrieves a list of all workflow cron triggers matching the criteria.
+        Retrieve a list of all workflow cron triggers matching the criteria.
 
-        Args:
-            offset (int | None): The offset to start the list from.
-            limit (int | None): The maximum number of items to return.
-            workflow_id (str | None): The ID of the workflow to filter by.
-            additional_metadata (list[str] | None): Filter by additional metadata keys (e.g. ["key1:value1", "key2:value2"]).
-            order_by_field (CronWorkflowsOrderByField | None): The field to order the list by.
-            order_by_direction (WorkflowRunOrderByDirection | None): The direction to order the list by.
+        :param offset: The offset to start the list from.
+        :param limit: The maximum number of items to return.
+        :param workflow_id: The ID of the workflow to filter by.
+        :param additional_metadata: Filter by additional metadata keys.
+        :param order_by_field: The field to order the list by.
+        :param order_by_direction: The direction to order the list by.
 
-        Returns:
-            CronWorkflowsList: A list of cron workflows.
+        :return: A list of cron workflows.
         """
         with self.client() as client:
             return self._wa(client).cron_workflow_list(
@@ -222,13 +231,10 @@
 
     def get(self, cron_id: str) -> CronWorkflows:
         """
-        Asynchronously retrieves a specific workflow cron trigger by ID.
-
-        Args:
-            cron_id (str): The cron trigger ID or CronWorkflows instance to retrieve.
+        Retrieve a specific workflow cron trigger by ID.
 
-        Returns:
-            CronWorkflows: The requested cron workflow instance.
+        :param cron_id: The cron trigger ID or CronWorkflows instance to retrieve.
+        :return: The requested cron workflow instance.
         """
         with self.client() as client:
             return self._wa(client).workflow_cron_get(
@@ -237,12 +243,9 @@
             )
     async def aio_get(self, cron_id: str) -> CronWorkflows:
         """
-        Synchronously retrieves a specific workflow cron trigger by ID.
-
-        Args:
-            cron_id (str): The cron trigger ID or CronWorkflows instance to retrieve.
+        Retrieve a specific workflow cron trigger by ID.
 
-        Returns:
-            CronWorkflows: The requested cron workflow instance.
+        :param cron_id: The cron trigger ID or CronWorkflows instance to retrieve.
+        :return: The requested cron workflow instance.
         """
         return await asyncio.to_thread(self.get, cron_id)
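
The CronClient changes above are documentation-only; the call signatures are unchanged. A hedged usage sketch against those signatures, assuming the client is reached through a configured Hatchet instance as hatchet.cron (how the client is exposed is an assumption, not shown in this diff, and the values are invented):

    from hatchet_sdk import Hatchet

    hatchet = Hatchet()  # reads connection settings from the environment

    cron = hatchet.cron.create(
        workflow_name="daily-report",
        cron_name="nightly",
        expression="0 2 * * *",
        input={"day": "today"},
        additional_metadata={"team": "data"},
    )

    crons = hatchet.cron.list(limit=10)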
@@ -7,12 +7,28 @@ from hatchet_sdk.clients.v1.api_client import BaseRestClient
 
 
 class LogsClient(BaseRestClient):
+    """
+    The logs client is a client for interacting with Hatchet's logs API.
+    """
+
     def _la(self, client: ApiClient) -> LogApi:
         return LogApi(client)
 
     def list(self, task_run_id: str) -> V1LogLineList:
+        """
+        List log lines for a given task run.
+
+        :param task_run_id: The ID of the task run to list logs for.
+        :return: A list of log lines for the specified task run.
+        """
         with self.client() as client:
             return self._la(client).v1_log_line_list(task=task_run_id)
 
     async def aio_list(self, task_run_id: str) -> V1LogLineList:
+        """
+        List log lines for a given task run.
+
+        :param task_run_id: The ID of the task run to list logs for.
+        :return: A list of log lines for the specified task run.
+        """
         return await asyncio.to_thread(self.list, task_run_id)
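
A short usage sketch for the newly documented log listing, reusing the hatchet instance from the previous sketch and assuming the client is exposed as hatchet.logs (an assumption; the task run id is invented):

    log_lines = hatchet.logs.list("some-task-run-id")
    # ...or, inside async code:
    # log_lines = await hatchet.logs.aio_list("some-task-run-id")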
@@ -17,6 +17,10 @@ from hatchet_sdk.utils.typing import JSONSerializableMapping
 
 
 class MetricsClient(BaseRestClient):
+    """
+    The metrics client is a client for reading metrics out of Hatchet's metrics API.
+    """
+
     def _wa(self, client: ApiClient) -> WorkflowApi:
         return WorkflowApi(client)
 
@@ -29,6 +33,15 @@ class MetricsClient(BaseRestClient):
         status: WorkflowRunStatus | None = None,
         group_key: str | None = None,
     ) -> WorkflowMetrics:
+        """
+        Retrieve workflow metrics for a given workflow ID.
+
+        :param workflow_id: The ID of the workflow to retrieve metrics for.
+        :param status: The status of the workflow run to filter by.
+        :param group_key: The key to group the metrics by.
+
+        :return: Workflow metrics for the specified workflow ID.
+        """
         with self.client() as client:
             return self._wa(client).workflow_get_metrics(
                 workflow=workflow_id, status=status, group_key=group_key
@@ -40,6 +53,15 @@
         status: WorkflowRunStatus | None = None,
         group_key: str | None = None,
     ) -> WorkflowMetrics:
+        """
+        Retrieve workflow metrics for a given workflow ID.
+
+        :param workflow_id: The ID of the workflow to retrieve metrics for.
+        :param status: The status of the workflow run to filter by.
+        :param group_key: The key to group the metrics by.
+
+        :return: Workflow metrics for the specified workflow ID.
+        """
         return await asyncio.to_thread(
             self.get_workflow_metrics, workflow_id, status, group_key
         )
@@ -49,6 +71,14 @@ class MetricsClient(BaseRestClient):
         workflow_ids: list[str] | None = None,
         additional_metadata: JSONSerializableMapping | None = None,
     ) -> TenantQueueMetrics:
+        """
+        Retrieve queue metrics for a set of workflow ids and additional metadata.
+
+        :param workflow_ids: A list of workflow IDs to retrieve metrics for.
+        :param additional_metadata: Additional metadata to filter the metrics by.
+
+        :return: Workflow metrics for the specified workflow IDs.
+        """
         with self.client() as client:
             return self._wa(client).tenant_get_queue_metrics(
                 tenant=self.client_config.tenant_id,
@@ -63,15 +93,33 @@
         workflow_ids: list[str] | None = None,
         additional_metadata: JSONSerializableMapping | None = None,
     ) -> TenantQueueMetrics:
+        """
+        Retrieve queue metrics for a set of workflow ids and additional metadata.
+
+        :param workflow_ids: A list of workflow IDs to retrieve metrics for.
+        :param additional_metadata: Additional metadata to filter the metrics by.
+
+        :return: Workflow metrics for the specified workflow IDs.
+        """
         return await asyncio.to_thread(
             self.get_queue_metrics, workflow_ids, additional_metadata
         )
 
     def get_task_metrics(self) -> TenantStepRunQueueMetrics:
+        """
+        Retrieve queue metrics
+
+        :return: Step run queue metrics for the tenant
+        """
         with self.client() as client:
             return self._ta(client).tenant_get_step_run_queue_metrics(
                 tenant=self.client_config.tenant_id
             )
 
     async def aio_get_task_metrics(self) -> TenantStepRunQueueMetrics:
+        """
+        Retrieve queue metrics
+
+        :return: Step run queue metrics for the tenant
+        """
         return await asyncio.to_thread(self.get_task_metrics)
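
As with the cron client, these metrics changes only add docstrings. A hedged usage sketch against the documented signatures, assuming the client hangs off a configured Hatchet instance as hatchet.metrics (the workflow id is invented):

    workflow_metrics = hatchet.metrics.get_workflow_metrics(workflow_id="wf-123")
    queue_metrics = hatchet.metrics.get_queue_metrics(workflow_ids=["wf-123"])
    task_metrics = hatchet.metrics.get_task_metrics()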
@@ -12,6 +12,10 @@ from hatchet_sdk.utils.proto_enums import convert_python_enum_to_proto
 
 
 class RateLimitsClient(BaseRestClient):
+    """
+    The rate limits client is a wrapper for Hatchet's gRPC API that makes it easier to work with rate limits in Hatchet.
+    """
+
     @tenacity_retry
     def put(
         self,
@@ -19,6 +23,16 @@ class RateLimitsClient(BaseRestClient):
         limit: int,
         duration: RateLimitDuration = RateLimitDuration.SECOND,
     ) -> None:
+        """
+        Put a rate limit for a given key.
+
+        :param key: The key to set the rate limit for.
+        :param limit: The rate limit to set.
+        :param duration: The duration of the rate limit.
+
+        :return: None
+        """
+
         duration_proto = convert_python_enum_to_proto(
             duration, workflow_protos.RateLimitDuration
         )
@@ -42,4 +56,14 @@
         limit: int,
         duration: RateLimitDuration = RateLimitDuration.SECOND,
     ) -> None:
+        """
+        Put a rate limit for a given key.
+
+        :param key: The key to set the rate limit for.
+        :param limit: The rate limit to set.
+        :param duration: The duration of the rate limit.
+
+        :return: None
+        """
+
         await asyncio.to_thread(self.put, key, limit, duration)
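
A final hedged sketch of the documented put call, assuming the client is exposed as hatchet.rate_limits on the configured Hatchet instance from the earlier sketch and that RateLimitDuration is importable from the package root (both assumptions; the key is invented):

    from hatchet_sdk import RateLimitDuration

    hatchet.rate_limits.put(
        key="external-api",  # hypothetical rate limit key shared by steps
        limit=10,
        duration=RateLimitDuration.SECOND,
    )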