hatchet-sdk 1.0.0__py3-none-any.whl → 1.0.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of hatchet-sdk might be problematic. Click here for more details.

Files changed (73)
  1. hatchet_sdk/__init__.py +32 -16
  2. hatchet_sdk/client.py +25 -63
  3. hatchet_sdk/clients/admin.py +203 -142
  4. hatchet_sdk/clients/dispatcher/action_listener.py +42 -42
  5. hatchet_sdk/clients/dispatcher/dispatcher.py +18 -16
  6. hatchet_sdk/clients/durable_event_listener.py +327 -0
  7. hatchet_sdk/clients/rest/__init__.py +12 -1
  8. hatchet_sdk/clients/rest/api/log_api.py +258 -0
  9. hatchet_sdk/clients/rest/api/task_api.py +32 -6
  10. hatchet_sdk/clients/rest/api/workflow_runs_api.py +626 -0
  11. hatchet_sdk/clients/rest/models/__init__.py +12 -1
  12. hatchet_sdk/clients/rest/models/v1_log_line.py +94 -0
  13. hatchet_sdk/clients/rest/models/v1_log_line_level.py +39 -0
  14. hatchet_sdk/clients/rest/models/v1_log_line_list.py +110 -0
  15. hatchet_sdk/clients/rest/models/v1_task_summary.py +80 -64
  16. hatchet_sdk/clients/rest/models/v1_trigger_workflow_run_request.py +95 -0
  17. hatchet_sdk/clients/rest/models/v1_workflow_run_display_name.py +98 -0
  18. hatchet_sdk/clients/rest/models/v1_workflow_run_display_name_list.py +114 -0
  19. hatchet_sdk/clients/rest/models/workflow_run_shape_item_for_workflow_run_details.py +9 -4
  20. hatchet_sdk/clients/rest/models/workflow_runs_metrics.py +5 -1
  21. hatchet_sdk/clients/run_event_listener.py +0 -1
  22. hatchet_sdk/clients/v1/api_client.py +81 -0
  23. hatchet_sdk/context/context.py +86 -159
  24. hatchet_sdk/contracts/dispatcher_pb2_grpc.py +1 -1
  25. hatchet_sdk/contracts/events_pb2.py +2 -2
  26. hatchet_sdk/contracts/events_pb2_grpc.py +1 -1
  27. hatchet_sdk/contracts/v1/dispatcher_pb2.py +36 -0
  28. hatchet_sdk/contracts/v1/dispatcher_pb2.pyi +38 -0
  29. hatchet_sdk/contracts/v1/dispatcher_pb2_grpc.py +145 -0
  30. hatchet_sdk/contracts/v1/shared/condition_pb2.py +39 -0
  31. hatchet_sdk/contracts/v1/shared/condition_pb2.pyi +72 -0
  32. hatchet_sdk/contracts/v1/shared/condition_pb2_grpc.py +29 -0
  33. hatchet_sdk/contracts/v1/workflows_pb2.py +67 -0
  34. hatchet_sdk/contracts/v1/workflows_pb2.pyi +228 -0
  35. hatchet_sdk/contracts/v1/workflows_pb2_grpc.py +234 -0
  36. hatchet_sdk/contracts/workflows_pb2_grpc.py +1 -1
  37. hatchet_sdk/features/cron.py +91 -121
  38. hatchet_sdk/features/logs.py +16 -0
  39. hatchet_sdk/features/metrics.py +75 -0
  40. hatchet_sdk/features/rate_limits.py +45 -0
  41. hatchet_sdk/features/runs.py +221 -0
  42. hatchet_sdk/features/scheduled.py +114 -131
  43. hatchet_sdk/features/workers.py +41 -0
  44. hatchet_sdk/features/workflows.py +55 -0
  45. hatchet_sdk/hatchet.py +463 -165
  46. hatchet_sdk/opentelemetry/instrumentor.py +8 -13
  47. hatchet_sdk/rate_limit.py +33 -39
  48. hatchet_sdk/runnables/contextvars.py +12 -0
  49. hatchet_sdk/runnables/standalone.py +192 -0
  50. hatchet_sdk/runnables/task.py +144 -0
  51. hatchet_sdk/runnables/types.py +138 -0
  52. hatchet_sdk/runnables/workflow.py +771 -0
  53. hatchet_sdk/utils/aio_utils.py +0 -79
  54. hatchet_sdk/utils/proto_enums.py +0 -7
  55. hatchet_sdk/utils/timedelta_to_expression.py +23 -0
  56. hatchet_sdk/utils/typing.py +2 -2
  57. hatchet_sdk/v0/clients/rest_client.py +9 -0
  58. hatchet_sdk/v0/worker/action_listener_process.py +18 -2
  59. hatchet_sdk/waits.py +120 -0
  60. hatchet_sdk/worker/action_listener_process.py +64 -30
  61. hatchet_sdk/worker/runner/run_loop_manager.py +35 -26
  62. hatchet_sdk/worker/runner/runner.py +72 -55
  63. hatchet_sdk/worker/runner/utils/capture_logs.py +3 -11
  64. hatchet_sdk/worker/worker.py +155 -118
  65. hatchet_sdk/workflow_run.py +4 -5
  66. {hatchet_sdk-1.0.0.dist-info → hatchet_sdk-1.0.1.dist-info}/METADATA +1 -2
  67. {hatchet_sdk-1.0.0.dist-info → hatchet_sdk-1.0.1.dist-info}/RECORD +69 -43
  68. {hatchet_sdk-1.0.0.dist-info → hatchet_sdk-1.0.1.dist-info}/entry_points.txt +2 -0
  69. hatchet_sdk/clients/rest_client.py +0 -636
  70. hatchet_sdk/semver.py +0 -30
  71. hatchet_sdk/worker/runner/utils/error_with_traceback.py +0 -6
  72. hatchet_sdk/workflow.py +0 -527
  73. {hatchet_sdk-1.0.0.dist-info → hatchet_sdk-1.0.1.dist-info}/WHEEL +0 -0
@@ -1,8 +1,11 @@
1
- from typing import List, Union
2
-
3
1
  from pydantic import BaseModel, Field, field_validator
4
2
 
5
- from hatchet_sdk.client import Client
3
+ from hatchet_sdk.clients.rest.api.workflow_api import WorkflowApi
4
+ from hatchet_sdk.clients.rest.api.workflow_run_api import WorkflowRunApi
5
+ from hatchet_sdk.clients.rest.api_client import ApiClient
6
+ from hatchet_sdk.clients.rest.models.create_cron_workflow_trigger_request import (
7
+ CreateCronWorkflowTriggerRequest,
8
+ )
6
9
  from hatchet_sdk.clients.rest.models.cron_workflows import CronWorkflows
7
10
  from hatchet_sdk.clients.rest.models.cron_workflows_list import CronWorkflowsList
8
11
  from hatchet_sdk.clients.rest.models.cron_workflows_order_by_field import (
@@ -11,10 +14,14 @@ from hatchet_sdk.clients.rest.models.cron_workflows_order_by_field import (
11
14
  from hatchet_sdk.clients.rest.models.workflow_run_order_by_direction import (
12
15
  WorkflowRunOrderByDirection,
13
16
  )
17
+ from hatchet_sdk.clients.v1.api_client import (
18
+ BaseRestClient,
19
+ maybe_additional_metadata_to_kv,
20
+ )
14
21
  from hatchet_sdk.utils.typing import JSONSerializableMapping
15
22
 
16
23
 
17
- class CreateCronTriggerJSONSerializableMapping(BaseModel):
24
+ class CreateCronTriggerConfig(BaseModel):
18
25
  """
19
26
  Schema for creating a workflow run triggered by a cron.
20
27
 
@@ -62,27 +69,14 @@ class CreateCronTriggerJSONSerializableMapping(BaseModel):
62
69
  return v
63
70
 
64
71
 
65
- class CronClient:
66
- """
67
- Client for managing workflow cron triggers synchronously.
68
-
69
- Attributes:
70
- _client (Client): The underlying client used to interact with the REST API.
71
- aio (CronClientAsync): Asynchronous counterpart of CronClient.
72
- """
73
-
74
- _client: Client
75
-
76
- def __init__(self, _client: Client):
77
- """
78
- Initializes the CronClient with a given Client instance.
72
+ class CronClient(BaseRestClient):
73
+ def _wra(self, client: ApiClient) -> WorkflowRunApi:
74
+ return WorkflowRunApi(client)
79
75
 
80
- Args:
81
- _client (Client): The client instance to be used for REST interactions.
82
- """
83
- self._client = _client
76
+ def _wa(self, client: ApiClient) -> WorkflowApi:
77
+ return WorkflowApi(client)
84
78
 
85
- def create(
79
+ async def aio_create(
86
80
  self,
87
81
  workflow_name: str,
88
82
  cron_name: str,
@@ -91,7 +85,7 @@ class CronClient:
91
85
  additional_metadata: JSONSerializableMapping,
92
86
  ) -> CronWorkflows:
93
87
  """
94
- Creates a new workflow cron trigger.
88
+ Asynchronously creates a new workflow cron trigger.
95
89
 
96
90
  Args:
97
91
  workflow_name (str): The name of the workflow to trigger.
@@ -103,42 +97,65 @@ class CronClient:
103
97
  Returns:
104
98
  CronWorkflows: The created cron workflow instance.
105
99
  """
106
- validated_input = CreateCronTriggerJSONSerializableMapping(
100
+ validated_input = CreateCronTriggerConfig(
107
101
  expression=expression, input=input, additional_metadata=additional_metadata
108
102
  )
109
103
 
110
- return self._client.rest.cron_create(
104
+ async with self.client() as client:
105
+ return await self._wra(client).cron_workflow_trigger_create(
106
+ tenant=self.client_config.tenant_id,
107
+ workflow=workflow_name,
108
+ create_cron_workflow_trigger_request=CreateCronWorkflowTriggerRequest(
109
+ cronName=cron_name,
110
+ cronExpression=validated_input.expression,
111
+ input=dict(validated_input.input),
112
+ additionalMetadata=dict(validated_input.additional_metadata),
113
+ ),
114
+ )
115
+
116
+ def create(
117
+ self,
118
+ workflow_name: str,
119
+ cron_name: str,
120
+ expression: str,
121
+ input: JSONSerializableMapping,
122
+ additional_metadata: JSONSerializableMapping,
123
+ ) -> CronWorkflows:
124
+ return self._run_async_from_sync(
125
+ self.aio_create,
111
126
  workflow_name,
112
127
  cron_name,
113
- validated_input.expression,
114
- validated_input.input,
115
- validated_input.additional_metadata,
128
+ expression,
129
+ input,
130
+ additional_metadata,
116
131
  )
117
132
 
118
- def delete(self, cron_trigger: Union[str, CronWorkflows]) -> None:
133
+ async def aio_delete(self, cron_id: str) -> None:
119
134
  """
120
- Deletes a workflow cron trigger.
135
+ Asynchronously deletes a workflow cron trigger.
121
136
 
122
137
  Args:
123
- cron_trigger (Union[str, CronWorkflows]): The cron trigger ID or CronWorkflows instance to delete.
138
+ cron_id (str): The cron trigger ID or CronWorkflows instance to delete.
124
139
  """
125
- self._client.rest.cron_delete(
126
- cron_trigger.metadata.id
127
- if isinstance(cron_trigger, CronWorkflows)
128
- else cron_trigger
129
- )
140
+ async with self.client() as client:
141
+ await self._wa(client).workflow_cron_delete(
142
+ tenant=self.client_config.tenant_id, cron_workflow=str(cron_id)
143
+ )
130
144
 
131
- def list(
145
+ def delete(self, cron_id: str) -> None:
146
+ return self._run_async_from_sync(self.aio_delete, cron_id)
147
+
148
+ async def aio_list(
132
149
  self,
133
150
  offset: int | None = None,
134
151
  limit: int | None = None,
135
152
  workflow_id: str | None = None,
136
- additional_metadata: list[str] | None = None,
153
+ additional_metadata: JSONSerializableMapping | None = None,
137
154
  order_by_field: CronWorkflowsOrderByField | None = None,
138
155
  order_by_direction: WorkflowRunOrderByDirection | None = None,
139
156
  ) -> CronWorkflowsList:
140
157
  """
141
- Retrieves a list of all workflow cron triggers matching the criteria.
158
+ Asynchronously retrieves a list of all workflow cron triggers matching the criteria.
142
159
 
143
160
  Args:
144
161
  offset (int | None): The offset to start the list from.
@@ -151,88 +168,30 @@ class CronClient:
151
168
  Returns:
152
169
  CronWorkflowsList: A list of cron workflows.
153
170
  """
154
- return self._client.rest.cron_list(
155
- offset=offset,
156
- limit=limit,
157
- workflow_id=workflow_id,
158
- additional_metadata=additional_metadata,
159
- order_by_field=order_by_field,
160
- order_by_direction=order_by_direction,
161
- )
162
-
163
- def get(self, cron_trigger: Union[str, CronWorkflows]) -> CronWorkflows:
164
- """
165
- Retrieves a specific workflow cron trigger by ID.
166
-
167
- Args:
168
- cron_trigger (Union[str, CronWorkflows]): The cron trigger ID or CronWorkflows instance to retrieve.
169
-
170
- Returns:
171
- CronWorkflows: The requested cron workflow instance.
172
- """
173
- return self._client.rest.cron_get(
174
- cron_trigger.metadata.id
175
- if isinstance(cron_trigger, CronWorkflows)
176
- else cron_trigger
177
- )
178
-
179
- async def aio_create(
180
- self,
181
- workflow_name: str,
182
- cron_name: str,
183
- expression: str,
184
- input: JSONSerializableMapping,
185
- additional_metadata: JSONSerializableMapping,
186
- ) -> CronWorkflows:
187
- """
188
- Asynchronously creates a new workflow cron trigger.
189
-
190
- Args:
191
- workflow_name (str): The name of the workflow to trigger.
192
- cron_name (str): The name of the cron trigger.
193
- expression (str): The cron expression defining the schedule.
194
- input (dict): The input data for the cron workflow.
195
- additional_metadata (dict[str, str]): Additional metadata associated with the cron trigger (e.g. {"key1": "value1", "key2": "value2"}).
196
-
197
- Returns:
198
- CronWorkflows: The created cron workflow instance.
199
- """
200
- validated_input = CreateCronTriggerJSONSerializableMapping(
201
- expression=expression, input=input, additional_metadata=additional_metadata
202
- )
203
-
204
- return await self._client.rest.aio_create_cron(
205
- workflow_name=workflow_name,
206
- cron_name=cron_name,
207
- expression=validated_input.expression,
208
- input=validated_input.input,
209
- additional_metadata=validated_input.additional_metadata,
210
- )
211
-
212
- async def aio_delete(self, cron_trigger: Union[str, CronWorkflows]) -> None:
213
- """
214
- Asynchronously deletes a workflow cron trigger.
215
-
216
- Args:
217
- cron_trigger (Union[str, CronWorkflows]): The cron trigger ID or CronWorkflows instance to delete.
218
- """
219
- await self._client.rest.aio_delete_cron(
220
- cron_trigger.metadata.id
221
- if isinstance(cron_trigger, CronWorkflows)
222
- else cron_trigger
223
- )
171
+ async with self.client() as client:
172
+ return await self._wa(client).cron_workflow_list(
173
+ tenant=self.client_config.tenant_id,
174
+ offset=offset,
175
+ limit=limit,
176
+ workflow_id=workflow_id,
177
+ additional_metadata=maybe_additional_metadata_to_kv(
178
+ additional_metadata
179
+ ),
180
+ order_by_field=order_by_field,
181
+ order_by_direction=order_by_direction,
182
+ )
224
183
 
225
- async def aio_list(
184
+ def list(
226
185
  self,
227
186
  offset: int | None = None,
228
187
  limit: int | None = None,
229
188
  workflow_id: str | None = None,
230
- additional_metadata: List[str] | None = None,
189
+ additional_metadata: JSONSerializableMapping | None = None,
231
190
  order_by_field: CronWorkflowsOrderByField | None = None,
232
191
  order_by_direction: WorkflowRunOrderByDirection | None = None,
233
192
  ) -> CronWorkflowsList:
234
193
  """
235
- Asynchronously retrieves a list of all workflow cron triggers matching the criteria.
194
+ Synchronously retrieves a list of all workflow cron triggers matching the criteria.
236
195
 
237
196
  Args:
238
197
  offset (int | None): The offset to start the list from.
@@ -245,7 +204,8 @@ class CronClient:
245
204
  Returns:
246
205
  CronWorkflowsList: A list of cron workflows.
247
206
  """
248
- return await self._client.rest.aio_list_crons(
207
+ return self._run_async_from_sync(
208
+ self.aio_list,
249
209
  offset=offset,
250
210
  limit=limit,
251
211
  workflow_id=workflow_id,
@@ -254,19 +214,29 @@ class CronClient:
254
214
  order_by_direction=order_by_direction,
255
215
  )
256
216
 
257
- async def aio_get(self, cron_trigger: Union[str, CronWorkflows]) -> CronWorkflows:
217
+ async def aio_get(self, cron_id: str) -> CronWorkflows:
258
218
  """
259
219
  Asynchronously retrieves a specific workflow cron trigger by ID.
260
220
 
261
221
  Args:
262
- cron_trigger (Union[str, CronWorkflows]): The cron trigger ID or CronWorkflows instance to retrieve.
222
+ cron_id (str): The cron trigger ID or CronWorkflows instance to retrieve.
263
223
 
264
224
  Returns:
265
225
  CronWorkflows: The requested cron workflow instance.
266
226
  """
227
+ async with self.client() as client:
228
+ return await self._wa(client).workflow_cron_get(
229
+ tenant=self.client_config.tenant_id, cron_workflow=str(cron_id)
230
+ )
231
+
232
+ def get(self, cron_id: str) -> CronWorkflows:
233
+ """
234
+ Synchronously retrieves a specific workflow cron trigger by ID.
267
235
 
268
- return await self._client.rest.aio_get_cron(
269
- cron_trigger.metadata.id
270
- if isinstance(cron_trigger, CronWorkflows)
271
- else cron_trigger
272
- )
236
+ Args:
237
+ cron_id (str): The cron trigger ID or CronWorkflows instance to retrieve.
238
+
239
+ Returns:
240
+ CronWorkflows: The requested cron workflow instance.
241
+ """
242
+ return self._run_async_from_sync(self.aio_get, cron_id)
@@ -0,0 +1,16 @@
1
+ from hatchet_sdk.clients.rest.api.log_api import LogApi
2
+ from hatchet_sdk.clients.rest.api_client import ApiClient
3
+ from hatchet_sdk.clients.rest.models.v1_log_line_list import V1LogLineList
4
+ from hatchet_sdk.clients.v1.api_client import BaseRestClient
5
+
6
+
7
class LogsClient(BaseRestClient):
    """Fetch the log lines recorded for a task run.

    ``aio_list`` performs the REST request; ``list`` is its synchronous
    counterpart and simply delegates through ``_run_async_from_sync``.
    """

    def _la(self, client: ApiClient) -> LogApi:
        """Wrap *client* in the generated ``LogApi``."""
        log_api = LogApi(client)
        return log_api

    async def aio_list(self, task_run_id: str) -> V1LogLineList:
        """Asynchronously list the log lines of the task run ``task_run_id``.

        Args:
            task_run_id: Identifier of the task run, sent as the ``task`` argument.

        Returns:
            The ``V1LogLineList`` returned by ``LogApi.v1_log_line_list``.
        """
        async with self.client() as api_client:
            log_api = self._la(api_client)
            return await log_api.v1_log_line_list(task=task_run_id)

    def list(self, task_run_id: str) -> V1LogLineList:
        """Synchronous wrapper around ``aio_list``."""
        return self._run_async_from_sync(self.aio_list, task_run_id)
@@ -0,0 +1,75 @@
1
+ from hatchet_sdk.clients.rest.api.tenant_api import TenantApi
2
+ from hatchet_sdk.clients.rest.api.workflow_api import WorkflowApi
3
+ from hatchet_sdk.clients.rest.api_client import ApiClient
4
+ from hatchet_sdk.clients.rest.models.tenant_queue_metrics import TenantQueueMetrics
5
+ from hatchet_sdk.clients.rest.models.tenant_step_run_queue_metrics import (
6
+ TenantStepRunQueueMetrics,
7
+ )
8
+ from hatchet_sdk.clients.rest.models.workflow_metrics import WorkflowMetrics
9
+ from hatchet_sdk.clients.rest.models.workflow_run_status import WorkflowRunStatus
10
+ from hatchet_sdk.clients.v1.api_client import (
11
+ BaseRestClient,
12
+ maybe_additional_metadata_to_kv,
13
+ )
14
+ from hatchet_sdk.utils.typing import JSONSerializableMapping
15
+
16
+
17
class MetricsClient(BaseRestClient):
    """Read workflow, queue and task metrics through the REST API.

    Every ``aio_*`` coroutine has a synchronous twin that runs it via
    ``_run_async_from_sync``.
    """

    def _wa(self, client: ApiClient) -> WorkflowApi:
        """Wrap *client* in the generated ``WorkflowApi``."""
        workflow_api = WorkflowApi(client)
        return workflow_api

    def _ta(self, client: ApiClient) -> TenantApi:
        """Wrap *client* in the generated ``TenantApi``."""
        tenant_api = TenantApi(client)
        return tenant_api

    async def aio_get_workflow_metrics(
        self,
        workflow_id: str,
        status: WorkflowRunStatus | None = None,
        group_key: str | None = None,
    ) -> WorkflowMetrics:
        """Asynchronously fetch metrics for a single workflow.

        Args:
            workflow_id: Workflow whose metrics are requested.
            status: Optional run-status filter forwarded to the API.
            group_key: Optional group key forwarded to the API.

        Returns:
            The ``WorkflowMetrics`` returned by ``WorkflowApi.workflow_get_metrics``.
        """
        async with self.client() as api_client:
            workflow_api = self._wa(api_client)
            return await workflow_api.workflow_get_metrics(
                workflow=workflow_id,
                status=status,
                group_key=group_key,
            )

    def get_workflow_metrics(
        self,
        workflow_id: str,
        status: WorkflowRunStatus | None = None,
        group_key: str | None = None,
    ) -> WorkflowMetrics:
        """Synchronous wrapper around ``aio_get_workflow_metrics``."""
        return self._run_async_from_sync(
            self.aio_get_workflow_metrics,
            workflow_id,
            status,
            group_key,
        )

    async def aio_get_queue_metrics(
        self,
        workflow_ids: list[str] | None = None,
        additional_metadata: JSONSerializableMapping | None = None,
    ) -> TenantQueueMetrics:
        """Asynchronously fetch queue metrics for the configured tenant.

        Args:
            workflow_ids: Optional list of workflow IDs to restrict the metrics to.
            additional_metadata: Optional metadata filter; converted with
                ``maybe_additional_metadata_to_kv`` before being sent.

        Returns:
            The ``TenantQueueMetrics`` returned by the API.
        """
        async with self.client() as api_client:
            workflow_api = self._wa(api_client)
            metadata_kv = maybe_additional_metadata_to_kv(additional_metadata)
            return await workflow_api.tenant_get_queue_metrics(
                tenant=self.client_config.tenant_id,
                workflows=workflow_ids,
                additional_metadata=metadata_kv,
            )

    def get_queue_metrics(
        self,
        workflow_ids: list[str] | None = None,
        additional_metadata: JSONSerializableMapping | None = None,
    ) -> TenantQueueMetrics:
        """Synchronous wrapper around ``aio_get_queue_metrics``."""
        return self._run_async_from_sync(
            self.aio_get_queue_metrics,
            workflow_ids,
            additional_metadata,
        )

    async def aio_get_task_metrics(self) -> TenantStepRunQueueMetrics:
        """Asynchronously fetch step-run queue metrics for the configured tenant."""
        async with self.client() as api_client:
            tenant_api = self._ta(api_client)
            return await tenant_api.tenant_get_step_run_queue_metrics(
                tenant=self.client_config.tenant_id
            )

    def get_task_metrics(self) -> TenantStepRunQueueMetrics:
        """Synchronous wrapper around ``aio_get_task_metrics``."""
        return self._run_async_from_sync(self.aio_get_task_metrics)
@@ -0,0 +1,45 @@
1
+ import asyncio
2
+
3
+ from hatchet_sdk.clients.rest.tenacity_utils import tenacity_retry
4
+ from hatchet_sdk.clients.v1.api_client import BaseRestClient
5
+ from hatchet_sdk.connection import new_conn
6
+ from hatchet_sdk.contracts import workflows_pb2 as v0_workflow_protos
7
+ from hatchet_sdk.contracts.v1 import workflows_pb2 as workflow_protos
8
+ from hatchet_sdk.contracts.workflows_pb2_grpc import WorkflowServiceStub
9
+ from hatchet_sdk.metadata import get_metadata
10
+ from hatchet_sdk.rate_limit import RateLimitDuration
11
+ from hatchet_sdk.utils.proto_enums import convert_python_enum_to_proto
12
+
13
+
14
class RateLimitsClient(BaseRestClient):
    """Manage tenant rate limits.

    Rate limits are written over gRPC (``WorkflowServiceStub.PutRateLimit``)
    rather than REST, so each call opens its own channel and closes it when done.
    """

    @tenacity_retry
    def put(
        self,
        key: str,
        limit: int,
        duration: RateLimitDuration = RateLimitDuration.SECOND,
    ) -> None:
        """Send a ``PutRateLimitRequest`` for ``key``.

        Args:
            key: Name of the rate limit.
            limit: Limit value to store for ``key``.
            duration: Window unit of the limit; converted to its proto enum.
        """
        duration_proto = convert_python_enum_to_proto(
            duration, workflow_protos.RateLimitDuration
        )

        conn = new_conn(self.client_config, False)
        try:
            client = WorkflowServiceStub(conn)  # type: ignore[no-untyped-call]

            client.PutRateLimit(
                v0_workflow_protos.PutRateLimitRequest(
                    key=key,
                    limit=limit,
                    duration=duration_proto,  # type: ignore[arg-type]
                ),
                metadata=get_metadata(self.client_config.token),
            )
        finally:
            # Without this, every call (and every retry attempt) leaked a channel.
            conn.close()

    async def aio_put(
        self,
        key: str,
        limit: int,
        duration: RateLimitDuration = RateLimitDuration.SECOND,
    ) -> None:
        """Asynchronous wrapper that runs ``put`` in a worker thread.

        Not decorated with ``@tenacity_retry``: ``put`` already retries, and a
        second retry loop around it would multiply attempts and backoff.
        """
        await asyncio.to_thread(self.put, key, limit, duration)
@@ -0,0 +1,221 @@
1
+ from datetime import datetime, timedelta
2
+ from typing import Literal, overload
3
+
4
+ from pydantic import BaseModel, model_validator
5
+
6
+ from hatchet_sdk.clients.rest.api.task_api import TaskApi
7
+ from hatchet_sdk.clients.rest.api.workflow_runs_api import WorkflowRunsApi
8
+ from hatchet_sdk.clients.rest.api_client import ApiClient
9
+ from hatchet_sdk.clients.rest.models.v1_cancel_task_request import V1CancelTaskRequest
10
+ from hatchet_sdk.clients.rest.models.v1_replay_task_request import V1ReplayTaskRequest
11
+ from hatchet_sdk.clients.rest.models.v1_task_filter import V1TaskFilter
12
+ from hatchet_sdk.clients.rest.models.v1_task_status import V1TaskStatus
13
+ from hatchet_sdk.clients.rest.models.v1_task_summary_list import V1TaskSummaryList
14
+ from hatchet_sdk.clients.rest.models.v1_trigger_workflow_run_request import (
15
+ V1TriggerWorkflowRunRequest,
16
+ )
17
+ from hatchet_sdk.clients.rest.models.v1_workflow_run_details import V1WorkflowRunDetails
18
+ from hatchet_sdk.clients.v1.api_client import (
19
+ BaseRestClient,
20
+ maybe_additional_metadata_to_kv,
21
+ )
22
+ from hatchet_sdk.utils.typing import JSONSerializableMapping
23
+
24
+
25
class RunFilter(BaseModel):
    """Filter criteria for selecting runs in bulk cancel/replay operations.

    Only ``since`` is required; every other criterion defaults to ``None``.
    These fields are mapped onto a ``V1TaskFilter`` by ``BulkCancelReplayOpts``.
    """

    # Lower bound of the time window (required).
    since: datetime
    # Optional upper bound of the time window.
    until: datetime | None = None
    # Optional status filter, passed through as ``V1TaskFilter.statuses``.
    statuses: list[V1TaskStatus] | None = None
    # Optional workflow-ID filter, passed through as ``V1TaskFilter.workflowIds``.
    workflow_ids: list[str] | None = None
    # Optional metadata filter, converted to key/value form before sending.
    additional_metadata: dict[str, str] | None = None
31
+
32
+
33
class BulkCancelReplayOpts(BaseModel):
    """Selection of task runs for a bulk cancel or replay request.

    Exactly one of ``ids`` or ``filters`` must be provided; this is enforced
    by ``validate_model``.
    """

    ids: list[str] | None = None
    filters: RunFilter | None = None

    @model_validator(mode="after")
    def validate_model(self) -> "BulkCancelReplayOpts":
        """Require exactly one of ``ids`` / ``filters``.

        Raises:
            ValueError: If neither or both are set. An empty ``ids`` list is
                treated as "not set" because the check uses truthiness.
        """
        if not self.ids and not self.filters:
            raise ValueError("ids or filters must be set")

        if self.ids and self.filters:
            raise ValueError("ids and filters cannot both be set")

        return self

    @property
    def v1_task_filter(self) -> V1TaskFilter | None:
        """Translate ``filters`` into the REST ``V1TaskFilter`` (``None`` if unset)."""
        if not self.filters:
            return None

        return V1TaskFilter(
            since=self.filters.since,
            until=self.filters.until,
            statuses=self.filters.statuses,
            workflowIds=self.filters.workflow_ids,
            additionalMetadata=maybe_additional_metadata_to_kv(
                self.filters.additional_metadata
            ),
        )

    @overload
    def to_request(self, request_type: Literal["replay"]) -> V1ReplayTaskRequest: ...

    @overload
    def to_request(self, request_type: Literal["cancel"]) -> V1CancelTaskRequest: ...

    def to_request(
        self, request_type: Literal["replay", "cancel"]
    ) -> V1ReplayTaskRequest | V1CancelTaskRequest:
        """Build the REST request body for ``request_type``.

        Args:
            request_type: ``"replay"`` or ``"cancel"``.

        Returns:
            A ``V1ReplayTaskRequest`` or ``V1CancelTaskRequest`` carrying either
            ``externalIds`` or ``filter``.

        Raises:
            ValueError: If ``request_type`` is not one of the supported values.
                Previously the method silently returned ``None`` in that case.
        """
        if request_type == "replay":
            return V1ReplayTaskRequest(
                externalIds=self.ids,
                filter=self.v1_task_filter,
            )

        if request_type == "cancel":
            return V1CancelTaskRequest(
                externalIds=self.ids,
                filter=self.v1_task_filter,
            )

        raise ValueError(f"Unsupported request type: {request_type!r}")
82
+
83
+
84
class RunsClient(BaseRestClient):
    """Read, trigger, replay and cancel workflow runs through the v1 REST API.

    Every operation is an ``aio_*`` coroutine with a synchronous counterpart
    that delegates to it via ``_run_async_from_sync``.
    """

    def _wra(self, client: ApiClient) -> WorkflowRunsApi:
        return WorkflowRunsApi(client)

    def _ta(self, client: ApiClient) -> TaskApi:
        return TaskApi(client)

    @staticmethod
    def _default_since() -> datetime:
        """Default lower bound for ``list``: one hour before *now*, in UTC.

        Computed at call time. The previous default expression was evaluated once
        at import, so the window went stale in long-running processes.
        """
        from datetime import timezone

        return datetime.now(tz=timezone.utc) - timedelta(hours=1)

    async def aio_get(self, workflow_run_id: str) -> V1WorkflowRunDetails:
        """Asynchronously fetch the details of workflow run ``workflow_run_id``."""
        async with self.client() as client:
            return await self._wra(client).v1_workflow_run_get(str(workflow_run_id))

    def get(self, workflow_run_id: str) -> V1WorkflowRunDetails:
        """Synchronous wrapper around ``aio_get``."""
        return self._run_async_from_sync(self.aio_get, workflow_run_id)

    async def aio_list(
        self,
        since: datetime | None = None,
        only_tasks: bool = False,
        offset: int | None = None,
        limit: int | None = None,
        statuses: list[V1TaskStatus] | None = None,
        until: datetime | None = None,
        additional_metadata: dict[str, str] | None = None,
        workflow_ids: list[str] | None = None,
        worker_id: str | None = None,
        parent_task_external_id: str | None = None,
    ) -> V1TaskSummaryList:
        """Asynchronously list runs for the configured tenant.

        Args:
            since: Lower bound of the listing window. Defaults to one hour
                before the call (UTC), computed on each call.
            only_tasks: Forwarded to the API as ``only_tasks``.
            offset: Pagination offset.
            limit: Pagination limit.
            statuses: Optional status filter.
            until: Optional upper bound of the listing window.
            additional_metadata: Optional metadata filter; converted with
                ``maybe_additional_metadata_to_kv``.
            workflow_ids: Optional workflow-ID filter.
            worker_id: Optional worker filter.
            parent_task_external_id: Optional parent-task filter.

        Returns:
            The ``V1TaskSummaryList`` returned by ``v1_workflow_run_list``.
        """
        if since is None:
            since = self._default_since()

        async with self.client() as client:
            return await self._wra(client).v1_workflow_run_list(
                tenant=self.client_config.tenant_id,
                since=since,
                only_tasks=only_tasks,
                offset=offset,
                limit=limit,
                statuses=statuses,
                until=until,
                additional_metadata=maybe_additional_metadata_to_kv(
                    additional_metadata
                ),
                workflow_ids=workflow_ids,
                worker_id=worker_id,
                parent_task_external_id=parent_task_external_id,
            )

    def list(
        self,
        since: datetime | None = None,
        only_tasks: bool = False,
        offset: int | None = None,
        limit: int | None = None,
        statuses: list[V1TaskStatus] | None = None,
        until: datetime | None = None,
        additional_metadata: dict[str, str] | None = None,
        workflow_ids: list[str] | None = None,
        worker_id: str | None = None,
        parent_task_external_id: str | None = None,
    ) -> V1TaskSummaryList:
        """Synchronous wrapper around ``aio_list``.

        A ``None`` ``since`` is resolved to the per-call default by ``aio_list``.
        """
        return self._run_async_from_sync(
            self.aio_list,
            since=since,
            only_tasks=only_tasks,
            offset=offset,
            limit=limit,
            statuses=statuses,
            until=until,
            additional_metadata=additional_metadata,
            workflow_ids=workflow_ids,
            worker_id=worker_id,
            parent_task_external_id=parent_task_external_id,
        )

    async def aio_create(
        self,
        workflow_name: str,
        input: JSONSerializableMapping,
        additional_metadata: JSONSerializableMapping | None = None,
    ) -> V1WorkflowRunDetails:
        """Asynchronously trigger a run of ``workflow_name``.

        Args:
            workflow_name: Name of the workflow to trigger.
            input: Workflow input; copied into a plain ``dict``.
            additional_metadata: Optional metadata; ``None`` means empty.

        Returns:
            The ``V1WorkflowRunDetails`` returned by ``v1_workflow_run_create``.
        """
        async with self.client() as client:
            return await self._wra(client).v1_workflow_run_create(
                tenant=self.client_config.tenant_id,
                v1_trigger_workflow_run_request=V1TriggerWorkflowRunRequest(
                    workflowName=workflow_name,
                    input=dict(input),
                    additionalMetadata=dict(additional_metadata or {}),
                ),
            )

    def create(
        self,
        workflow_name: str,
        input: JSONSerializableMapping,
        additional_metadata: JSONSerializableMapping | None = None,
    ) -> V1WorkflowRunDetails:
        """Synchronous wrapper around ``aio_create``."""
        return self._run_async_from_sync(
            self.aio_create, workflow_name, input, additional_metadata
        )

    async def aio_replay(self, run_id: str) -> None:
        """Replay a single run (a one-element ``aio_bulk_replay``)."""
        await self.aio_bulk_replay(opts=BulkCancelReplayOpts(ids=[run_id]))

    def replay(self, run_id: str) -> None:
        """Synchronous wrapper around ``aio_replay``."""
        return self._run_async_from_sync(self.aio_replay, run_id)

    async def aio_bulk_replay(self, opts: BulkCancelReplayOpts) -> None:
        """Replay every run selected by ``opts`` via ``TaskApi.v1_task_replay``."""
        async with self.client() as client:
            await self._ta(client).v1_task_replay(
                tenant=self.client_config.tenant_id,
                v1_replay_task_request=opts.to_request("replay"),
            )

    def bulk_replay(self, opts: BulkCancelReplayOpts) -> None:
        """Synchronous wrapper around ``aio_bulk_replay``."""
        return self._run_async_from_sync(self.aio_bulk_replay, opts)

    async def aio_cancel(self, run_id: str) -> None:
        """Cancel a single run (a one-element ``aio_bulk_cancel``)."""
        await self.aio_bulk_cancel(opts=BulkCancelReplayOpts(ids=[run_id]))

    def cancel(self, run_id: str) -> None:
        """Synchronous wrapper around ``aio_cancel``."""
        return self._run_async_from_sync(self.aio_cancel, run_id)

    async def aio_bulk_cancel(self, opts: BulkCancelReplayOpts) -> None:
        """Cancel every run selected by ``opts`` via ``TaskApi.v1_task_cancel``."""
        async with self.client() as client:
            await self._ta(client).v1_task_cancel(
                tenant=self.client_config.tenant_id,
                v1_cancel_task_request=opts.to_request("cancel"),
            )

    def bulk_cancel(self, opts: BulkCancelReplayOpts) -> None:
        """Synchronous wrapper around ``aio_bulk_cancel``."""
        return self._run_async_from_sync(self.aio_bulk_cancel, opts)

    async def aio_get_result(self, run_id: str) -> JSONSerializableMapping:
        """Return ``run.output`` from the details of run ``run_id``."""
        details = await self.aio_get(run_id)

        return details.run.output

    def get_result(self, run_id: str) -> JSONSerializableMapping:
        """Synchronous counterpart of ``aio_get_result``."""
        details = self.get(run_id)

        return details.run.output