hatchet-sdk: hatchet_sdk-1.2.5-py3-none-any.whl → hatchet_sdk-1.3.0-py3-none-any.whl

This diff shows the contents of two publicly available versions of this package, as released to one of the supported registries. It is provided for informational purposes only and reflects the changes between the two versions exactly as they appear in their respective public registries.

Potentially problematic release.


This version of hatchet-sdk has been flagged as potentially problematic; see the registry's release details for more information.

Files changed (60)
  1. hatchet_sdk/__init__.py +7 -5
  2. hatchet_sdk/client.py +14 -6
  3. hatchet_sdk/clients/admin.py +57 -15
  4. hatchet_sdk/clients/dispatcher/action_listener.py +2 -2
  5. hatchet_sdk/clients/dispatcher/dispatcher.py +20 -7
  6. hatchet_sdk/clients/event_ts.py +25 -5
  7. hatchet_sdk/clients/listeners/durable_event_listener.py +125 -0
  8. hatchet_sdk/clients/listeners/pooled_listener.py +255 -0
  9. hatchet_sdk/clients/listeners/workflow_listener.py +62 -0
  10. hatchet_sdk/clients/rest/api/api_token_api.py +24 -24
  11. hatchet_sdk/clients/rest/api/default_api.py +64 -64
  12. hatchet_sdk/clients/rest/api/event_api.py +64 -64
  13. hatchet_sdk/clients/rest/api/github_api.py +8 -8
  14. hatchet_sdk/clients/rest/api/healthcheck_api.py +16 -16
  15. hatchet_sdk/clients/rest/api/log_api.py +16 -16
  16. hatchet_sdk/clients/rest/api/metadata_api.py +24 -24
  17. hatchet_sdk/clients/rest/api/rate_limits_api.py +8 -8
  18. hatchet_sdk/clients/rest/api/slack_api.py +16 -16
  19. hatchet_sdk/clients/rest/api/sns_api.py +24 -24
  20. hatchet_sdk/clients/rest/api/step_run_api.py +56 -56
  21. hatchet_sdk/clients/rest/api/task_api.py +56 -56
  22. hatchet_sdk/clients/rest/api/tenant_api.py +128 -128
  23. hatchet_sdk/clients/rest/api/user_api.py +96 -96
  24. hatchet_sdk/clients/rest/api/worker_api.py +24 -24
  25. hatchet_sdk/clients/rest/api/workflow_api.py +144 -144
  26. hatchet_sdk/clients/rest/api/workflow_run_api.py +48 -48
  27. hatchet_sdk/clients/rest/api/workflow_runs_api.py +40 -40
  28. hatchet_sdk/clients/rest/api_client.py +5 -8
  29. hatchet_sdk/clients/rest/configuration.py +7 -3
  30. hatchet_sdk/clients/rest/models/tenant_step_run_queue_metrics.py +2 -2
  31. hatchet_sdk/clients/rest/models/v1_task_summary.py +5 -0
  32. hatchet_sdk/clients/rest/models/workflow_runs_metrics.py +5 -1
  33. hatchet_sdk/clients/rest/rest.py +160 -111
  34. hatchet_sdk/clients/v1/api_client.py +2 -2
  35. hatchet_sdk/context/context.py +22 -21
  36. hatchet_sdk/features/cron.py +41 -40
  37. hatchet_sdk/features/logs.py +7 -6
  38. hatchet_sdk/features/metrics.py +19 -18
  39. hatchet_sdk/features/runs.py +88 -68
  40. hatchet_sdk/features/scheduled.py +42 -42
  41. hatchet_sdk/features/workers.py +17 -16
  42. hatchet_sdk/features/workflows.py +15 -14
  43. hatchet_sdk/hatchet.py +1 -1
  44. hatchet_sdk/runnables/standalone.py +12 -9
  45. hatchet_sdk/runnables/task.py +66 -2
  46. hatchet_sdk/runnables/types.py +8 -0
  47. hatchet_sdk/runnables/workflow.py +48 -136
  48. hatchet_sdk/waits.py +8 -8
  49. hatchet_sdk/worker/runner/run_loop_manager.py +4 -4
  50. hatchet_sdk/worker/runner/runner.py +22 -11
  51. hatchet_sdk/worker/worker.py +29 -25
  52. hatchet_sdk/workflow_run.py +55 -9
  53. {hatchet_sdk-1.2.5.dist-info → hatchet_sdk-1.3.0.dist-info}/METADATA +1 -1
  54. {hatchet_sdk-1.2.5.dist-info → hatchet_sdk-1.3.0.dist-info}/RECORD +57 -57
  55. hatchet_sdk/clients/durable_event_listener.py +0 -329
  56. hatchet_sdk/clients/workflow_listener.py +0 -288
  57. hatchet_sdk/utils/aio.py +0 -43
  58. hatchet_sdk/clients/{run_event_listener.py → listeners/run_event_listener.py} +0 -0
  59. {hatchet_sdk-1.2.5.dist-info → hatchet_sdk-1.3.0.dist-info}/WHEEL +0 -0
  60. {hatchet_sdk-1.2.5.dist-info → hatchet_sdk-1.3.0.dist-info}/entry_points.txt +0 -0
hatchet_sdk/__init__.py CHANGED
@@ -3,8 +3,14 @@ from hatchet_sdk.clients.admin import (
3
3
  ScheduleTriggerWorkflowOptions,
4
4
  TriggerWorkflowOptions,
5
5
  )
6
- from hatchet_sdk.clients.durable_event_listener import RegisterDurableEventRequest
7
6
  from hatchet_sdk.clients.events import PushEventOptions
7
+ from hatchet_sdk.clients.listeners.durable_event_listener import (
8
+ RegisterDurableEventRequest,
9
+ )
10
+ from hatchet_sdk.clients.listeners.run_event_listener import (
11
+ StepRunEventType,
12
+ WorkflowRunEventType,
13
+ )
8
14
  from hatchet_sdk.clients.rest.models.accept_invite_request import AcceptInviteRequest
9
15
 
10
16
  # import models into sdk package
@@ -124,10 +130,6 @@ from hatchet_sdk.clients.rest.models.workflow_version_definition import (
124
130
  WorkflowVersionDefinition,
125
131
  )
126
132
  from hatchet_sdk.clients.rest.models.workflow_version_meta import WorkflowVersionMeta
127
- from hatchet_sdk.clients.run_event_listener import (
128
- StepRunEventType,
129
- WorkflowRunEventType,
130
- )
131
133
  from hatchet_sdk.config import ClientConfig
132
134
  from hatchet_sdk.context.context import Context, DurableContext
133
135
  from hatchet_sdk.context.worker_context import WorkerContext
hatchet_sdk/client.py CHANGED
@@ -1,8 +1,8 @@
1
1
  from hatchet_sdk.clients.admin import AdminClient
2
2
  from hatchet_sdk.clients.dispatcher.dispatcher import DispatcherClient
3
3
  from hatchet_sdk.clients.events import EventClient
4
- from hatchet_sdk.clients.run_event_listener import RunEventListenerClient
5
- from hatchet_sdk.clients.workflow_listener import PooledWorkflowRunListener
4
+ from hatchet_sdk.clients.listeners.run_event_listener import RunEventListenerClient
5
+ from hatchet_sdk.clients.listeners.workflow_listener import PooledWorkflowRunListener
6
6
  from hatchet_sdk.config import ClientConfig
7
7
  from hatchet_sdk.features.cron import CronClient
8
8
  from hatchet_sdk.features.logs import LogsClient
@@ -21,15 +21,15 @@ class Client:
21
21
  event_client: EventClient | None = None,
22
22
  admin_client: AdminClient | None = None,
23
23
  dispatcher_client: DispatcherClient | None = None,
24
- workflow_listener: PooledWorkflowRunListener | None | None = None,
24
+ workflow_listener: PooledWorkflowRunListener | None = None,
25
25
  debug: bool = False,
26
26
  ):
27
27
  self.config = config
28
- self.admin = admin_client or AdminClient(config)
29
28
  self.dispatcher = dispatcher_client or DispatcherClient(config)
30
29
  self.event = event_client or EventClient(config)
31
30
  self.listener = RunEventListenerClient(config)
32
- self.workflow_listener = workflow_listener
31
+ self.workflow_listener = workflow_listener or PooledWorkflowRunListener(config)
32
+
33
33
  self.log_interceptor = config.logger
34
34
  self.debug = debug
35
35
 
@@ -37,7 +37,15 @@ class Client:
37
37
  self.logs = LogsClient(self.config)
38
38
  self.metrics = MetricsClient(self.config)
39
39
  self.rate_limits = RateLimitsClient(self.config)
40
- self.runs = RunsClient(self.config)
40
+ self.runs = RunsClient(
41
+ config=self.config,
42
+ workflow_run_event_listener=self.listener,
43
+ workflow_run_listener=self.workflow_listener,
44
+ )
41
45
  self.scheduled = ScheduledClient(self.config)
42
46
  self.workers = WorkersClient(self.config)
43
47
  self.workflows = WorkflowsClient(self.config)
48
+
49
+ self.admin = admin_client or AdminClient(
50
+ config, self.workflow_listener, self.listener, self.runs
51
+ )
@@ -7,6 +7,8 @@ import grpc
7
7
  from google.protobuf import timestamp_pb2
8
8
  from pydantic import BaseModel, ConfigDict, Field, field_validator
9
9
 
10
+ from hatchet_sdk.clients.listeners.run_event_listener import RunEventListenerClient
11
+ from hatchet_sdk.clients.listeners.workflow_listener import PooledWorkflowRunListener
10
12
  from hatchet_sdk.clients.rest.tenacity_utils import tenacity_retry
11
13
  from hatchet_sdk.config import ClientConfig
12
14
  from hatchet_sdk.connection import new_conn
@@ -14,6 +16,7 @@ from hatchet_sdk.contracts import workflows_pb2 as v0_workflow_protos
14
16
  from hatchet_sdk.contracts.v1 import workflows_pb2 as workflow_protos
15
17
  from hatchet_sdk.contracts.v1.workflows_pb2_grpc import AdminServiceStub
16
18
  from hatchet_sdk.contracts.workflows_pb2_grpc import WorkflowServiceStub
19
+ from hatchet_sdk.features.runs import RunsClient
17
20
  from hatchet_sdk.metadata import get_metadata
18
21
  from hatchet_sdk.rate_limit import RateLimitDuration
19
22
  from hatchet_sdk.runnables.contextvars import (
@@ -63,14 +66,31 @@ class DedupeViolationErr(Exception):
63
66
 
64
67
 
65
68
  class AdminClient:
66
- def __init__(self, config: ClientConfig):
67
- conn = new_conn(config, False)
69
+ def __init__(
70
+ self,
71
+ config: ClientConfig,
72
+ workflow_run_listener: PooledWorkflowRunListener,
73
+ workflow_run_event_listener: RunEventListenerClient,
74
+ runs_client: RunsClient,
75
+ ):
68
76
  self.config = config
69
- self.client = AdminServiceStub(conn)
70
- self.v0_client = WorkflowServiceStub(conn)
77
+ self.runs_client = runs_client
71
78
  self.token = config.token
72
79
  self.namespace = config.namespace
73
80
 
81
+ self.workflow_run_listener = workflow_run_listener
82
+ self.workflow_run_event_listener = workflow_run_event_listener
83
+
84
+ self.client: AdminServiceStub | None = None
85
+ self.v0_client: WorkflowServiceStub | None = None
86
+
87
+ def _get_or_create_v0_client(self) -> WorkflowServiceStub:
88
+ if self.v0_client is None:
89
+ conn = new_conn(self.config, False)
90
+ self.v0_client = WorkflowServiceStub(conn)
91
+
92
+ return self.v0_client
93
+
74
94
  class TriggerWorkflowRequest(BaseModel):
75
95
  model_config = ConfigDict(extra="ignore")
76
96
 
@@ -199,6 +219,10 @@ class AdminClient:
199
219
  ) -> workflow_protos.CreateWorkflowVersionResponse:
200
220
  opts = self._prepare_put_workflow_request(name, workflow, overrides)
201
221
 
222
+ if self.client is None:
223
+ conn = new_conn(self.config, False)
224
+ self.client = AdminServiceStub(conn)
225
+
202
226
  return cast(
203
227
  workflow_protos.CreateWorkflowVersionResponse,
204
228
  self.client.PutWorkflow(
@@ -218,7 +242,9 @@ class AdminClient:
218
242
  duration, workflow_protos.RateLimitDuration
219
243
  )
220
244
 
221
- self.v0_client.PutRateLimit(
245
+ client = self._get_or_create_v0_client()
246
+
247
+ client.PutRateLimit(
222
248
  v0_workflow_protos.PutRateLimitRequest(
223
249
  key=key,
224
250
  limit=limit,
@@ -245,9 +271,11 @@ class AdminClient:
245
271
  name, schedules, input, options
246
272
  )
247
273
 
274
+ client = self._get_or_create_v0_client()
275
+
248
276
  return cast(
249
277
  v0_workflow_protos.WorkflowVersion,
250
- self.v0_client.ScheduleWorkflow(
278
+ client.ScheduleWorkflow(
251
279
  request,
252
280
  metadata=get_metadata(self.token),
253
281
  ),
@@ -305,11 +333,12 @@ class AdminClient:
305
333
  options: TriggerWorkflowOptions = TriggerWorkflowOptions(),
306
334
  ) -> WorkflowRunRef:
307
335
  request = self._create_workflow_run_request(workflow_name, input, options)
336
+ client = self._get_or_create_v0_client()
308
337
 
309
338
  try:
310
339
  resp = cast(
311
340
  v0_workflow_protos.TriggerWorkflowResponse,
312
- self.v0_client.TriggerWorkflow(
341
+ client.TriggerWorkflow(
313
342
  request,
314
343
  metadata=get_metadata(self.token),
315
344
  ),
@@ -321,7 +350,9 @@ class AdminClient:
321
350
 
322
351
  return WorkflowRunRef(
323
352
  workflow_run_id=resp.workflow_run_id,
324
- config=self.config,
353
+ workflow_run_event_listener=self.workflow_run_event_listener,
354
+ workflow_run_listener=self.workflow_run_listener,
355
+ runs_client=self.runs_client,
325
356
  )
326
357
 
327
358
  ## IMPORTANT: Keep this method's signature in sync with the wrapper in the OTel instrumentor
@@ -332,13 +363,14 @@ class AdminClient:
332
363
  input: JSONSerializableMapping,
333
364
  options: TriggerWorkflowOptions = TriggerWorkflowOptions(),
334
365
  ) -> WorkflowRunRef:
366
+ client = self._get_or_create_v0_client()
335
367
  async with spawn_index_lock:
336
368
  request = self._create_workflow_run_request(workflow_name, input, options)
337
369
 
338
370
  try:
339
371
  resp = cast(
340
372
  v0_workflow_protos.TriggerWorkflowResponse,
341
- self.v0_client.TriggerWorkflow(
373
+ client.TriggerWorkflow(
342
374
  request,
343
375
  metadata=get_metadata(self.token),
344
376
  ),
@@ -350,8 +382,10 @@ class AdminClient:
350
382
  raise e
351
383
 
352
384
  return WorkflowRunRef(
385
+ runs_client=self.runs_client,
353
386
  workflow_run_id=resp.workflow_run_id,
354
- config=self.config,
387
+ workflow_run_event_listener=self.workflow_run_event_listener,
388
+ workflow_run_listener=self.workflow_run_listener,
355
389
  )
356
390
 
357
391
  def chunk(self, xs: list[T], n: int) -> Generator[list[T], None, None]:
@@ -364,6 +398,7 @@ class AdminClient:
364
398
  self,
365
399
  workflows: list[WorkflowRunTriggerConfig],
366
400
  ) -> list[WorkflowRunRef]:
401
+ client = self._get_or_create_v0_client()
367
402
  bulk_workflows = [
368
403
  self._create_workflow_run_request(
369
404
  workflow.workflow_name, workflow.input, workflow.options
@@ -380,7 +415,7 @@ class AdminClient:
380
415
 
381
416
  resp = cast(
382
417
  v0_workflow_protos.BulkTriggerWorkflowResponse,
383
- self.v0_client.BulkTriggerWorkflow(
418
+ client.BulkTriggerWorkflow(
384
419
  bulk_request,
385
420
  metadata=get_metadata(self.token),
386
421
  ),
@@ -390,7 +425,9 @@ class AdminClient:
390
425
  [
391
426
  WorkflowRunRef(
392
427
  workflow_run_id=workflow_run_id,
393
- config=self.config,
428
+ workflow_run_event_listener=self.workflow_run_event_listener,
429
+ workflow_run_listener=self.workflow_run_listener,
430
+ runs_client=self.runs_client,
394
431
  )
395
432
  for workflow_run_id in resp.workflow_run_ids
396
433
  ]
@@ -403,6 +440,7 @@ class AdminClient:
403
440
  self,
404
441
  workflows: list[WorkflowRunTriggerConfig],
405
442
  ) -> list[WorkflowRunRef]:
443
+ client = self._get_or_create_v0_client()
406
444
  chunks = self.chunk(workflows, MAX_BULK_WORKFLOW_RUN_BATCH_SIZE)
407
445
  refs: list[WorkflowRunRef] = []
408
446
 
@@ -421,7 +459,7 @@ class AdminClient:
421
459
 
422
460
  resp = cast(
423
461
  v0_workflow_protos.BulkTriggerWorkflowResponse,
424
- self.v0_client.BulkTriggerWorkflow(
462
+ client.BulkTriggerWorkflow(
425
463
  bulk_request,
426
464
  metadata=get_metadata(self.token),
427
465
  ),
@@ -431,7 +469,9 @@ class AdminClient:
431
469
  [
432
470
  WorkflowRunRef(
433
471
  workflow_run_id=workflow_run_id,
434
- config=self.config,
472
+ workflow_run_event_listener=self.workflow_run_event_listener,
473
+ workflow_run_listener=self.workflow_run_listener,
474
+ runs_client=self.runs_client,
435
475
  )
436
476
  for workflow_run_id in resp.workflow_run_ids
437
477
  ]
@@ -441,6 +481,8 @@ class AdminClient:
441
481
 
442
482
  def get_workflow_run(self, workflow_run_id: str) -> WorkflowRunRef:
443
483
  return WorkflowRunRef(
484
+ runs_client=self.runs_client,
444
485
  workflow_run_id=workflow_run_id,
445
- config=self.config,
486
+ workflow_run_event_listener=self.workflow_run_event_listener,
487
+ workflow_run_listener=self.workflow_run_listener,
446
488
  )
@@ -12,7 +12,7 @@ from pydantic import BaseModel, ConfigDict, Field, field_validator, model_valida
12
12
 
13
13
  from hatchet_sdk.clients.event_ts import ThreadSafeEvent, read_with_interrupt
14
14
  from hatchet_sdk.clients.events import proto_timestamp_now
15
- from hatchet_sdk.clients.run_event_listener import (
15
+ from hatchet_sdk.clients.listeners.run_event_listener import (
16
16
  DEFAULT_ACTION_LISTENER_RETRY_INTERVAL,
17
17
  )
18
18
  from hatchet_sdk.config import ClientConfig
@@ -275,7 +275,7 @@ class ActionListener:
275
275
 
276
276
  break
277
277
 
278
- assigned_action = t.result()
278
+ assigned_action, _ = t.result()
279
279
 
280
280
  if assigned_action is cygrpc.EOF:
281
281
  self.retries = self.retries + 1
@@ -35,15 +35,20 @@ DEFAULT_REGISTER_TIMEOUT = 30
35
35
 
36
36
  class DispatcherClient:
37
37
  def __init__(self, config: ClientConfig):
38
- conn = new_conn(config, False)
39
- self.client = DispatcherStub(conn)
40
-
41
38
  self.token = config.token
42
39
  self.config = config
43
40
 
44
41
  ## IMPORTANT: This needs to be created lazily so we don't require
45
42
  ## an event loop to instantiate the client.
46
43
  self.aio_client: DispatcherStub | None = None
44
+ self.client: DispatcherStub | None = None
45
+
46
+ def _get_or_create_client(self) -> DispatcherStub:
47
+ if self.client is None:
48
+ conn = new_conn(self.config, False)
49
+ self.client = DispatcherStub(conn)
50
+
51
+ return self.client
47
52
 
48
53
  async def get_action_listener(
49
54
  self, req: GetActionListenerRequest
@@ -167,23 +172,29 @@ class DispatcherClient:
167
172
  )
168
173
 
169
174
  def put_overrides_data(self, data: OverridesData) -> ActionEventResponse:
175
+ client = self._get_or_create_client()
176
+
170
177
  return cast(
171
178
  ActionEventResponse,
172
- self.client.PutOverridesData(
179
+ client.PutOverridesData(
173
180
  data,
174
181
  metadata=get_metadata(self.token),
175
182
  ),
176
183
  )
177
184
 
178
185
  def release_slot(self, step_run_id: str) -> None:
179
- self.client.ReleaseSlot(
186
+ client = self._get_or_create_client()
187
+
188
+ client.ReleaseSlot(
180
189
  ReleaseSlotRequest(stepRunId=step_run_id),
181
190
  timeout=DEFAULT_REGISTER_TIMEOUT,
182
191
  metadata=get_metadata(self.token),
183
192
  )
184
193
 
185
194
  def refresh_timeout(self, step_run_id: str, increment_by: str) -> None:
186
- self.client.RefreshTimeout(
195
+ client = self._get_or_create_client()
196
+
197
+ client.RefreshTimeout(
187
198
  RefreshTimeoutRequest(
188
199
  stepRunId=step_run_id,
189
200
  incrementTimeoutBy=increment_by,
@@ -203,7 +214,9 @@ class DispatcherClient:
203
214
  else:
204
215
  worker_labels[key] = WorkerLabels(strValue=str(value))
205
216
 
206
- self.client.UpsertWorkerLabels(
217
+ client = self._get_or_create_client()
218
+
219
+ client.UpsertWorkerLabels(
207
220
  UpsertWorkerLabelsRequest(workerId=worker_id, labels=worker_labels),
208
221
  timeout=DEFAULT_REGISTER_TIMEOUT,
209
222
  metadata=get_metadata(self.token),
@@ -1,5 +1,5 @@
1
1
  import asyncio
2
- from typing import TypeVar, cast
2
+ from typing import Callable, TypeVar, cast, overload
3
3
 
4
4
  import grpc.aio
5
5
  from grpc._cython import cygrpc # type: ignore[attr-defined]
@@ -27,15 +27,35 @@ TRequest = TypeVar("TRequest")
27
27
  TResponse = TypeVar("TResponse")
28
28
 
29
29
 
30
+ @overload
30
31
  async def read_with_interrupt(
31
- listener: grpc.aio.UnaryStreamCall[TRequest, TResponse], interrupt: ThreadSafeEvent
32
- ) -> TResponse:
32
+ listener: grpc.aio.UnaryStreamCall[TRequest, TResponse],
33
+ interrupt: ThreadSafeEvent,
34
+ key_generator: Callable[[TResponse], str],
35
+ ) -> tuple[TResponse, str]: ...
36
+
37
+
38
+ @overload
39
+ async def read_with_interrupt(
40
+ listener: grpc.aio.UnaryStreamCall[TRequest, TResponse],
41
+ interrupt: ThreadSafeEvent,
42
+ key_generator: None = None,
43
+ ) -> tuple[TResponse, None]: ...
44
+
45
+
46
+ async def read_with_interrupt(
47
+ listener: grpc.aio.UnaryStreamCall[TRequest, TResponse],
48
+ interrupt: ThreadSafeEvent,
49
+ key_generator: Callable[[TResponse], str] | None = None,
50
+ ) -> tuple[TResponse, str | None]:
33
51
  try:
34
- result = await listener.read()
52
+ result = cast(TResponse, await listener.read())
35
53
 
36
54
  if result is cygrpc.EOF:
37
55
  raise ValueError("Unexpected EOF")
38
56
 
39
- return cast(TResponse, result)
57
+ key = key_generator(result) if key_generator else None
58
+
59
+ return result, key
40
60
  finally:
41
61
  interrupt.set()
@@ -0,0 +1,125 @@
1
+ import json
2
+ from collections.abc import AsyncIterator
3
+ from typing import Any, Literal, cast
4
+
5
+ import grpc
6
+ import grpc.aio
7
+ from pydantic import BaseModel, ConfigDict
8
+
9
+ from hatchet_sdk.clients.listeners.pooled_listener import PooledListener
10
+ from hatchet_sdk.clients.rest.tenacity_utils import tenacity_retry
11
+ from hatchet_sdk.connection import new_conn
12
+ from hatchet_sdk.contracts.v1.dispatcher_pb2 import (
13
+ DurableEvent,
14
+ ListenForDurableEventRequest,
15
+ )
16
+ from hatchet_sdk.contracts.v1.dispatcher_pb2 import (
17
+ RegisterDurableEventRequest as RegisterDurableEventRequestProto,
18
+ )
19
+ from hatchet_sdk.contracts.v1.dispatcher_pb2_grpc import V1DispatcherStub
20
+ from hatchet_sdk.contracts.v1.shared.condition_pb2 import DurableEventListenerConditions
21
+ from hatchet_sdk.metadata import get_metadata
22
+ from hatchet_sdk.waits import SleepCondition, UserEventCondition
23
+
24
+ DEFAULT_DURABLE_EVENT_LISTENER_RETRY_INTERVAL = 3 # seconds
25
+ DEFAULT_DURABLE_EVENT_LISTENER_RETRY_COUNT = 5
26
+ DEFAULT_DURABLE_EVENT_LISTENER_INTERRUPT_INTERVAL = 1800 # 30 minutes
27
+
28
+
29
+ class RegisterDurableEventRequest(BaseModel):
30
+ model_config = ConfigDict(arbitrary_types_allowed=True)
31
+
32
+ task_id: str
33
+ signal_key: str
34
+ conditions: list[SleepCondition | UserEventCondition]
35
+
36
+ def to_proto(self) -> RegisterDurableEventRequestProto:
37
+ return RegisterDurableEventRequestProto(
38
+ task_id=self.task_id,
39
+ signal_key=self.signal_key,
40
+ conditions=DurableEventListenerConditions(
41
+ sleep_conditions=[
42
+ c.to_proto()
43
+ for c in self.conditions
44
+ if isinstance(c, SleepCondition)
45
+ ],
46
+ user_event_conditions=[
47
+ c.to_proto()
48
+ for c in self.conditions
49
+ if isinstance(c, UserEventCondition)
50
+ ],
51
+ ),
52
+ )
53
+
54
+
55
+ class ParsedKey(BaseModel):
56
+ task_id: str
57
+ signal_key: str
58
+
59
+
60
+ class DurableEventListener(
61
+ PooledListener[ListenForDurableEventRequest, DurableEvent, V1DispatcherStub]
62
+ ):
63
+ def _generate_key(self, task_id: str, signal_key: str) -> str:
64
+ return task_id + ":" + signal_key
65
+
66
+ def generate_key(self, response: DurableEvent) -> str:
67
+ return self._generate_key(
68
+ task_id=response.task_id,
69
+ signal_key=response.signal_key,
70
+ )
71
+
72
+ def parse_key(self, key: str) -> ParsedKey:
73
+ task_id, signal_key = key.split(":", maxsplit=1)
74
+
75
+ return ParsedKey(
76
+ task_id=task_id,
77
+ signal_key=signal_key,
78
+ )
79
+
80
+ async def create_subscription(
81
+ self,
82
+ request: AsyncIterator[ListenForDurableEventRequest],
83
+ metadata: tuple[tuple[str, str]],
84
+ ) -> grpc.aio.UnaryStreamCall[ListenForDurableEventRequest, DurableEvent]:
85
+ if self.client is None:
86
+ conn = new_conn(self.config, True)
87
+ self.client = V1DispatcherStub(conn)
88
+
89
+ return cast(
90
+ grpc.aio.UnaryStreamCall[ListenForDurableEventRequest, DurableEvent],
91
+ self.client.ListenForDurableEvent(
92
+ request, # type: ignore[arg-type]
93
+ metadata=metadata,
94
+ ),
95
+ )
96
+
97
+ def create_request_body(self, item: str) -> ListenForDurableEventRequest:
98
+ key = self.parse_key(item)
99
+ return ListenForDurableEventRequest(
100
+ task_id=key.task_id,
101
+ signal_key=key.signal_key,
102
+ )
103
+
104
+ @tenacity_retry
105
+ def register_durable_event(
106
+ self, request: RegisterDurableEventRequest
107
+ ) -> Literal[True]:
108
+ conn = new_conn(self.config, True)
109
+ client = V1DispatcherStub(conn)
110
+
111
+ client.RegisterDurableEvent(
112
+ request.to_proto(),
113
+ timeout=5,
114
+ metadata=get_metadata(self.token),
115
+ )
116
+
117
+ return True
118
+
119
+ @tenacity_retry
120
+ async def result(self, task_id: str, signal_key: str) -> dict[str, Any]:
121
+ key = self._generate_key(task_id, signal_key)
122
+
123
+ event = await self.subscribe(key)
124
+
125
+ return cast(dict[str, Any], json.loads(event.data.decode("utf-8")))