hatchet-sdk 1.2.5__py3-none-any.whl → 1.3.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of hatchet-sdk may be problematic; see the package's registry page for more details.

Files changed (60)
  1. hatchet_sdk/__init__.py +7 -5
  2. hatchet_sdk/client.py +14 -6
  3. hatchet_sdk/clients/admin.py +57 -15
  4. hatchet_sdk/clients/dispatcher/action_listener.py +2 -2
  5. hatchet_sdk/clients/dispatcher/dispatcher.py +20 -7
  6. hatchet_sdk/clients/event_ts.py +25 -5
  7. hatchet_sdk/clients/listeners/durable_event_listener.py +125 -0
  8. hatchet_sdk/clients/listeners/pooled_listener.py +255 -0
  9. hatchet_sdk/clients/listeners/workflow_listener.py +62 -0
  10. hatchet_sdk/clients/rest/api/api_token_api.py +24 -24
  11. hatchet_sdk/clients/rest/api/default_api.py +64 -64
  12. hatchet_sdk/clients/rest/api/event_api.py +64 -64
  13. hatchet_sdk/clients/rest/api/github_api.py +8 -8
  14. hatchet_sdk/clients/rest/api/healthcheck_api.py +16 -16
  15. hatchet_sdk/clients/rest/api/log_api.py +16 -16
  16. hatchet_sdk/clients/rest/api/metadata_api.py +24 -24
  17. hatchet_sdk/clients/rest/api/rate_limits_api.py +8 -8
  18. hatchet_sdk/clients/rest/api/slack_api.py +16 -16
  19. hatchet_sdk/clients/rest/api/sns_api.py +24 -24
  20. hatchet_sdk/clients/rest/api/step_run_api.py +56 -56
  21. hatchet_sdk/clients/rest/api/task_api.py +56 -56
  22. hatchet_sdk/clients/rest/api/tenant_api.py +128 -128
  23. hatchet_sdk/clients/rest/api/user_api.py +96 -96
  24. hatchet_sdk/clients/rest/api/worker_api.py +24 -24
  25. hatchet_sdk/clients/rest/api/workflow_api.py +144 -144
  26. hatchet_sdk/clients/rest/api/workflow_run_api.py +48 -48
  27. hatchet_sdk/clients/rest/api/workflow_runs_api.py +40 -40
  28. hatchet_sdk/clients/rest/api_client.py +5 -8
  29. hatchet_sdk/clients/rest/configuration.py +7 -3
  30. hatchet_sdk/clients/rest/models/tenant_step_run_queue_metrics.py +2 -2
  31. hatchet_sdk/clients/rest/models/v1_task_summary.py +5 -0
  32. hatchet_sdk/clients/rest/models/workflow_runs_metrics.py +5 -1
  33. hatchet_sdk/clients/rest/rest.py +160 -111
  34. hatchet_sdk/clients/v1/api_client.py +2 -2
  35. hatchet_sdk/context/context.py +22 -21
  36. hatchet_sdk/features/cron.py +41 -40
  37. hatchet_sdk/features/logs.py +7 -6
  38. hatchet_sdk/features/metrics.py +19 -18
  39. hatchet_sdk/features/runs.py +88 -68
  40. hatchet_sdk/features/scheduled.py +42 -42
  41. hatchet_sdk/features/workers.py +17 -16
  42. hatchet_sdk/features/workflows.py +15 -14
  43. hatchet_sdk/hatchet.py +1 -1
  44. hatchet_sdk/runnables/standalone.py +12 -9
  45. hatchet_sdk/runnables/task.py +66 -2
  46. hatchet_sdk/runnables/types.py +8 -0
  47. hatchet_sdk/runnables/workflow.py +48 -136
  48. hatchet_sdk/waits.py +8 -8
  49. hatchet_sdk/worker/runner/run_loop_manager.py +4 -4
  50. hatchet_sdk/worker/runner/runner.py +22 -11
  51. hatchet_sdk/worker/worker.py +29 -25
  52. hatchet_sdk/workflow_run.py +55 -9
  53. {hatchet_sdk-1.2.5.dist-info → hatchet_sdk-1.3.0.dist-info}/METADATA +1 -1
  54. {hatchet_sdk-1.2.5.dist-info → hatchet_sdk-1.3.0.dist-info}/RECORD +57 -57
  55. hatchet_sdk/clients/durable_event_listener.py +0 -329
  56. hatchet_sdk/clients/workflow_listener.py +0 -288
  57. hatchet_sdk/utils/aio.py +0 -43
  58. hatchet_sdk/clients/{run_event_listener.py → listeners/run_event_listener.py} +0 -0
  59. {hatchet_sdk-1.2.5.dist-info → hatchet_sdk-1.3.0.dist-info}/WHEEL +0 -0
  60. {hatchet_sdk-1.2.5.dist-info → hatchet_sdk-1.3.0.dist-info}/entry_points.txt +0 -0
@@ -0,0 +1,255 @@
1
+ import asyncio
2
+ from abc import ABC, abstractmethod
3
+ from collections.abc import AsyncIterator
4
+ from typing import Generic, Literal, TypeVar
5
+
6
+ import grpc
7
+ import grpc.aio
8
+ from grpc._cython import cygrpc # type: ignore[attr-defined]
9
+
10
+ from hatchet_sdk.clients.event_ts import ThreadSafeEvent, read_with_interrupt
11
+ from hatchet_sdk.config import ClientConfig
12
+ from hatchet_sdk.logger import logger
13
+ from hatchet_sdk.metadata import get_metadata
14
+
15
+ DEFAULT_LISTENER_RETRY_INTERVAL = 3 # seconds
16
+ DEFAULT_LISTENER_RETRY_COUNT = 5
17
+ DEFAULT_LISTENER_INTERRUPT_INTERVAL = 1800 # 30 minutes
18
+
19
+ R = TypeVar("R")
20
+ T = TypeVar("T")
21
+ L = TypeVar("L")
22
+
23
+ SentinelValue = Literal["STOP"]
24
+ SENTINEL_VALUE: SentinelValue = "STOP"
25
+
26
+
27
+ TRequest = TypeVar("TRequest")
28
+ TResponse = TypeVar("TResponse")
29
+
30
+
31
+ class Subscription(Generic[T]):
32
+ def __init__(self, id: int) -> None:
33
+ self.id = id
34
+ self.queue: asyncio.Queue[T | SentinelValue] = asyncio.Queue()
35
+
36
+ async def __aiter__(self) -> "Subscription[T]":
37
+ return self
38
+
39
+ async def __anext__(self) -> T | SentinelValue:
40
+ return await self.queue.get()
41
+
42
+ async def get(self) -> T:
43
+ event = await self.queue.get()
44
+
45
+ if event == "STOP":
46
+ raise StopAsyncIteration
47
+
48
+ return event
49
+
50
+ async def put(self, item: T) -> None:
51
+ await self.queue.put(item)
52
+
53
+ async def close(self) -> None:
54
+ await self.queue.put("STOP")
55
+
56
+
57
class PooledListener(Generic[R, T, L], ABC):
    """Multiplex many keyed subscriptions over one shared gRPC stream.

    Each call to ``subscribe`` registers interest in a key, sends a request for
    it on the shared stream, and waits for the first response whose key (as
    computed by ``generate_key``) matches.

    Subclasses provide:
        generate_key: map a response to the key it belongs to.
        create_request_body: build the request message for a key.
        create_subscription: open the stream from a request iterator.
    """

    def __init__(self, config: ClientConfig):
        self.token = config.token
        self.config = config

        # subscription id -> key that subscription is waiting on
        self.from_subscriptions: dict[int, str] = {}
        # key -> ids of all subscriptions waiting on that key
        self.to_subscriptions: dict[str, list[int]] = {}

        self.subscription_counter: int = 0
        self.subscription_counter_lock: asyncio.Lock = asyncio.Lock()

        # Outgoing requests. An int entry is a stop signal addressed to the
        # request generator with that number (see _request / _retry_subscribe).
        self.requests: asyncio.Queue[R | int] = asyncio.Queue()

        self.listener: grpc.aio.UnaryStreamCall[R, T] | None = None
        self.listener_task: asyncio.Task[None] | None = None

        # Number of the most recently started request generator (0 = none yet).
        self.curr_requester: int = 0

        self.events: dict[int, Subscription[T]] = {}

        self.interrupter: asyncio.Task[None] | None = None
        # Event the producer waits on for each read; set by the interrupter.
        # Initialised here so _interrupter can safely test it for None.
        self.interrupt: ThreadSafeEvent | None = None

        ## IMPORTANT: This needs to be created lazily so we don't require
        ## an event loop to instantiate the client.
        self.client: L | None = None

    async def _interrupter(self) -> None:
        """Interrupt the current stream after DEFAULT_LISTENER_INTERRUPT_INTERVAL.

        Runs as an asyncio task (not a thread). Setting the interrupt event
        makes ``_init_producer`` cancel the in-flight read and the stream, then
        reconnect.
        """
        await asyncio.sleep(DEFAULT_LISTENER_INTERRUPT_INTERVAL)

        if self.interrupt is not None:
            self.interrupt.set()

    @abstractmethod
    def generate_key(self, response: T) -> str:
        """Return the subscription key that ``response`` should be routed to."""
        pass

    async def _init_producer(self) -> None:
        """Connect to the stream and fan responses out to subscriptions.

        Loops forever. Each pass (re)connects via ``_retry_subscribe``, restarts
        the interrupter, and reads responses. Each response is delivered to
        every subscription registered for its key.

        The stream is re-established when it is interrupted, reaches EOF, or
        raises ``grpc.RpcError``. Any other exception closes every open
        subscription and is re-raised.
        """
        try:
            if not self.listener:
                while True:
                    try:
                        self.listener = await self._retry_subscribe()

                        logger.debug("Listener connected.")

                        # spawn an interrupter task (replacing any previous one)
                        if self.interrupter is not None and not self.interrupter.done():
                            self.interrupter.cancel()

                        self.interrupter = asyncio.create_task(self._interrupter())

                        while True:
                            self.interrupt = ThreadSafeEvent()
                            if self.listener is None:
                                # Nothing to read from: reconnect instead of
                                # spinning in a loop that never awaits.
                                break

                            read_task = asyncio.create_task(
                                read_with_interrupt(
                                    self.listener, self.interrupt, self.generate_key
                                )
                            )
                            # NOTE(review): assumes read_with_interrupt sets the
                            # event once a response is read — confirm in event_ts.
                            await self.interrupt.wait()

                            if not read_task.done():
                                logger.warning(
                                    "Interrupted read_with_interrupt task of listener"
                                )

                                read_task.cancel()
                                self.listener.cancel()

                                await asyncio.sleep(DEFAULT_LISTENER_RETRY_INTERVAL)
                                break

                            event, key = read_task.result()

                            if event is cygrpc.EOF:
                                break

                            # Snapshot the ids: the list can change while we await.
                            for subscription_id in list(self.to_subscriptions.get(key, [])):
                                subscription = self.events.get(subscription_id)
                                if subscription is not None:
                                    await subscription.put(event)

                    except grpc.RpcError as e:
                        logger.debug(f"grpc error in listener: {e}")
                        await asyncio.sleep(DEFAULT_LISTENER_RETRY_INTERVAL)
                        continue

        except Exception as e:
            logger.error(f"Error in listener: {e}")

            self.listener = None

            # close all subscriptions; iterate a snapshot because subscribers
            # remove themselves from self.events when they finish.
            for subscription in list(self.events.values()):
                await subscription.close()

            raise

    @abstractmethod
    def create_request_body(self, item: str) -> R:
        """Build the request message that subscribes to key ``item``."""
        pass

    async def _request(self) -> AsyncIterator[R]:
        """Yield the requests sent on one stream.

        Each generator takes the next requester number. It first re-sends a
        request for every key that still has a live subscription, which is
        needed after a reconnect. It then forwards queued requests until it
        dequeues its own number, which ``_retry_subscribe`` enqueues to retire
        it. Stop signals addressed to other generators are ignored.
        """
        # Capture our own number: comparing against the live attribute would
        # let a newer generator's increment make us ignore our own stop signal.
        requester_id = self.curr_requester + 1
        self.curr_requester = requester_id

        for key in set(self.from_subscriptions.values()):
            yield self.create_request_body(key)

        while True:
            request = await self.requests.get()
            # Balance every get() so Queue.join() would behave correctly.
            self.requests.task_done()

            if isinstance(request, int):
                if request == requester_id:
                    break
                # stop signal for a different (older) generator
                continue

            yield request

    def cleanup_subscription(self, subscription_id: int) -> None:
        """Forget a subscription; drop its key entry once nobody waits on it.

        Safe to call more than once for the same id.
        """
        key = self.from_subscriptions.pop(subscription_id, None)

        if key is not None:
            waiting = self.to_subscriptions.get(key)
            if waiting is not None:
                if subscription_id in waiting:
                    waiting.remove(subscription_id)
                # Without this, to_subscriptions gains one empty entry per key
                # ever subscribed and grows without bound.
                if not waiting:
                    del self.to_subscriptions[key]

        self.events.pop(subscription_id, None)

    async def subscribe(self, id: str) -> T:
        """Wait for and return the first response routed to key ``id``.

        Registers a subscription and queues a request for ``id``. It also
        starts the producer task if it is not running. The subscription is
        always cleaned up before returning.

        Raises:
            StopAsyncIteration: if the producer closes the subscription. It
                does so when it fails with a non-gRPC error.
        """
        subscription_id: int | None = None

        try:
            async with self.subscription_counter_lock:
                self.subscription_counter += 1
                subscription_id = self.subscription_counter

            self.from_subscriptions[subscription_id] = id
            self.to_subscriptions.setdefault(id, []).append(subscription_id)
            self.events[subscription_id] = Subscription(subscription_id)

            await self.requests.put(self.create_request_body(id))

            if not self.listener_task or self.listener_task.done():
                self.listener_task = asyncio.create_task(self._init_producer())

            return await self.events[subscription_id].get()
        finally:
            if subscription_id is not None:
                self.cleanup_subscription(subscription_id)

    @abstractmethod
    async def create_subscription(
        self, request: AsyncIterator[R], metadata: tuple[tuple[str, str]]
    ) -> grpc.aio.UnaryStreamCall[R, T]:
        """Open the stream, consuming requests from ``request``."""
        pass

    async def _retry_subscribe(
        self,
    ) -> grpc.aio.UnaryStreamCall[R, T]:
        """Open a new stream, retrying while the server reports UNAVAILABLE.

        Before each attempt, the most recent request generator is told to stop.
        This leaves only the new stream's generator consuming ``self.requests``.

        Raises:
            ValueError: on any other gRPC error, or once
                DEFAULT_LISTENER_RETRY_COUNT UNAVAILABLE failures occur.
        """
        retries = 0
        while retries < DEFAULT_LISTENER_RETRY_COUNT:
            try:
                if retries > 0:
                    await asyncio.sleep(DEFAULT_LISTENER_RETRY_INTERVAL)

                # signal previous async iterator to stop
                if self.curr_requester != 0:
                    self.requests.put_nowait(self.curr_requester)

                return await self.create_subscription(
                    self._request(),
                    metadata=get_metadata(self.token),
                )

            except grpc.RpcError as e:
                if e.code() == grpc.StatusCode.UNAVAILABLE:
                    retries += 1
                else:
                    raise ValueError(f"gRPC error: {e}") from e

        raise ValueError("Failed to connect to listener")
@@ -0,0 +1,62 @@
1
+ import json
2
+ from typing import Any, AsyncIterator, cast
3
+
4
+ import grpc
5
+ import grpc.aio
6
+
7
+ from hatchet_sdk.clients.listeners.pooled_listener import PooledListener
8
+ from hatchet_sdk.connection import new_conn
9
+ from hatchet_sdk.contracts.dispatcher_pb2 import (
10
+ SubscribeToWorkflowRunsRequest,
11
+ WorkflowRunEvent,
12
+ )
13
+ from hatchet_sdk.contracts.dispatcher_pb2_grpc import DispatcherStub
14
+
15
# Substring of a step error that marks a deduplication violation.
DEDUPE_MESSAGE = "DUPLICATE_WORKFLOW_RUN"


class PooledWorkflowRunListener(
    PooledListener[SubscribeToWorkflowRunsRequest, WorkflowRunEvent, DispatcherStub]
):
    """Pooled listener that waits for workflow-run events.

    Subscriptions are keyed by workflow run id. The stream is opened through
    the dispatcher's ``SubscribeToWorkflowRuns`` RPC.
    """

    def create_request_body(self, item: str) -> SubscribeToWorkflowRunsRequest:
        """Build the subscribe request for workflow run ``item``."""
        return SubscribeToWorkflowRunsRequest(workflowRunId=item)

    def generate_key(self, response: WorkflowRunEvent) -> str:
        """Route an event by the workflow run id it carries."""
        return response.workflowRunId

    async def aio_result(self, id: str) -> dict[str, Any]:
        """Wait for workflow run ``id`` and return its decoded step outputs.

        Returns:
            Mapping of step readable id to JSON-decoded output. Only steps
            with non-empty output are included.

        Raises:
            DedupeViolationErr: if the first reported error contains
                DEDUPE_MESSAGE.
            Exception: if any step reported an error.
        """
        # Function-scope import; presumably avoids an import cycle — verify.
        from hatchet_sdk.clients.admin import DedupeViolationErr

        run_event = await self.subscribe(id)

        step_errors = [r.error for r in run_event.results if r.error]
        if step_errors:
            first_error = step_errors[0]
            if DEDUPE_MESSAGE in first_error:
                raise DedupeViolationErr(first_error)
            raise Exception(f"Workflow Errors: {step_errors}")

        outputs: dict[str, Any] = {}
        for r in run_event.results:
            if r.output:
                outputs[r.stepReadableId] = json.loads(r.output)
        return outputs

    async def create_subscription(
        self,
        request: AsyncIterator[SubscribeToWorkflowRunsRequest],
        metadata: tuple[tuple[str, str]],
    ) -> grpc.aio.UnaryStreamCall[SubscribeToWorkflowRunsRequest, WorkflowRunEvent]:
        """Open the SubscribeToWorkflowRuns stream, building the stub on first use."""
        if self.client is None:
            # Built lazily so constructing the listener needs no event loop.
            self.client = DispatcherStub(new_conn(self.config, True))

        call = self.client.SubscribeToWorkflowRuns(
            request,  # type: ignore[arg-type]
            metadata=metadata,
        )
        return cast(
            grpc.aio.UnaryStreamCall[SubscribeToWorkflowRunsRequest, WorkflowRunEvent],
            call,
        )
@@ -44,7 +44,7 @@ class APITokenApi:
44
44
  self.api_client = api_client
45
45
 
46
46
  @validate_call
47
- async def api_token_create(
47
+ def api_token_create(
48
48
  self,
49
49
  tenant: Annotated[
50
50
  str,
@@ -109,17 +109,17 @@ class APITokenApi:
109
109
  "400": "APIErrors",
110
110
  "403": "APIErrors",
111
111
  }
112
- response_data = await self.api_client.call_api(
112
+ response_data = self.api_client.call_api(
113
113
  *_param, _request_timeout=_request_timeout
114
114
  )
115
- await response_data.read()
115
+ response_data.read()
116
116
  return self.api_client.response_deserialize(
117
117
  response_data=response_data,
118
118
  response_types_map=_response_types_map,
119
119
  ).data
120
120
 
121
121
  @validate_call
122
- async def api_token_create_with_http_info(
122
+ def api_token_create_with_http_info(
123
123
  self,
124
124
  tenant: Annotated[
125
125
  str,
@@ -184,17 +184,17 @@ class APITokenApi:
184
184
  "400": "APIErrors",
185
185
  "403": "APIErrors",
186
186
  }
187
- response_data = await self.api_client.call_api(
187
+ response_data = self.api_client.call_api(
188
188
  *_param, _request_timeout=_request_timeout
189
189
  )
190
- await response_data.read()
190
+ response_data.read()
191
191
  return self.api_client.response_deserialize(
192
192
  response_data=response_data,
193
193
  response_types_map=_response_types_map,
194
194
  )
195
195
 
196
196
  @validate_call
197
- async def api_token_create_without_preload_content(
197
+ def api_token_create_without_preload_content(
198
198
  self,
199
199
  tenant: Annotated[
200
200
  str,
@@ -259,7 +259,7 @@ class APITokenApi:
259
259
  "400": "APIErrors",
260
260
  "403": "APIErrors",
261
261
  }
262
- response_data = await self.api_client.call_api(
262
+ response_data = self.api_client.call_api(
263
263
  *_param, _request_timeout=_request_timeout
264
264
  )
265
265
  return response_data.response
@@ -332,7 +332,7 @@ class APITokenApi:
332
332
  )
333
333
 
334
334
  @validate_call
335
- async def api_token_list(
335
+ def api_token_list(
336
336
  self,
337
337
  tenant: Annotated[
338
338
  str,
@@ -393,17 +393,17 @@ class APITokenApi:
393
393
  "400": "APIErrors",
394
394
  "403": "APIErrors",
395
395
  }
396
- response_data = await self.api_client.call_api(
396
+ response_data = self.api_client.call_api(
397
397
  *_param, _request_timeout=_request_timeout
398
398
  )
399
- await response_data.read()
399
+ response_data.read()
400
400
  return self.api_client.response_deserialize(
401
401
  response_data=response_data,
402
402
  response_types_map=_response_types_map,
403
403
  ).data
404
404
 
405
405
  @validate_call
406
- async def api_token_list_with_http_info(
406
+ def api_token_list_with_http_info(
407
407
  self,
408
408
  tenant: Annotated[
409
409
  str,
@@ -464,17 +464,17 @@ class APITokenApi:
464
464
  "400": "APIErrors",
465
465
  "403": "APIErrors",
466
466
  }
467
- response_data = await self.api_client.call_api(
467
+ response_data = self.api_client.call_api(
468
468
  *_param, _request_timeout=_request_timeout
469
469
  )
470
- await response_data.read()
470
+ response_data.read()
471
471
  return self.api_client.response_deserialize(
472
472
  response_data=response_data,
473
473
  response_types_map=_response_types_map,
474
474
  )
475
475
 
476
476
  @validate_call
477
- async def api_token_list_without_preload_content(
477
+ def api_token_list_without_preload_content(
478
478
  self,
479
479
  tenant: Annotated[
480
480
  str,
@@ -535,7 +535,7 @@ class APITokenApi:
535
535
  "400": "APIErrors",
536
536
  "403": "APIErrors",
537
537
  }
538
- response_data = await self.api_client.call_api(
538
+ response_data = self.api_client.call_api(
539
539
  *_param, _request_timeout=_request_timeout
540
540
  )
541
541
  return response_data.response
@@ -595,7 +595,7 @@ class APITokenApi:
595
595
  )
596
596
 
597
597
  @validate_call
598
- async def api_token_update_revoke(
598
+ def api_token_update_revoke(
599
599
  self,
600
600
  api_token: Annotated[
601
601
  str,
@@ -656,17 +656,17 @@ class APITokenApi:
656
656
  "400": "APIErrors",
657
657
  "403": "APIErrors",
658
658
  }
659
- response_data = await self.api_client.call_api(
659
+ response_data = self.api_client.call_api(
660
660
  *_param, _request_timeout=_request_timeout
661
661
  )
662
- await response_data.read()
662
+ response_data.read()
663
663
  return self.api_client.response_deserialize(
664
664
  response_data=response_data,
665
665
  response_types_map=_response_types_map,
666
666
  ).data
667
667
 
668
668
  @validate_call
669
- async def api_token_update_revoke_with_http_info(
669
+ def api_token_update_revoke_with_http_info(
670
670
  self,
671
671
  api_token: Annotated[
672
672
  str,
@@ -727,17 +727,17 @@ class APITokenApi:
727
727
  "400": "APIErrors",
728
728
  "403": "APIErrors",
729
729
  }
730
- response_data = await self.api_client.call_api(
730
+ response_data = self.api_client.call_api(
731
731
  *_param, _request_timeout=_request_timeout
732
732
  )
733
- await response_data.read()
733
+ response_data.read()
734
734
  return self.api_client.response_deserialize(
735
735
  response_data=response_data,
736
736
  response_types_map=_response_types_map,
737
737
  )
738
738
 
739
739
  @validate_call
740
- async def api_token_update_revoke_without_preload_content(
740
+ def api_token_update_revoke_without_preload_content(
741
741
  self,
742
742
  api_token: Annotated[
743
743
  str,
@@ -798,7 +798,7 @@ class APITokenApi:
798
798
  "400": "APIErrors",
799
799
  "403": "APIErrors",
800
800
  }
801
- response_data = await self.api_client.call_api(
801
+ response_data = self.api_client.call_api(
802
802
  *_param, _request_timeout=_request_timeout
803
803
  )
804
804
  return response_data.response