hatchet-sdk 1.0.3__py3-none-any.whl → 1.2.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of hatchet-sdk might be problematic. Click here for more details.

Files changed (42) hide show
  1. hatchet_sdk/__init__.py +2 -1
  2. hatchet_sdk/client.py +3 -16
  3. hatchet_sdk/clients/admin.py +7 -32
  4. hatchet_sdk/clients/dispatcher/action_listener.py +4 -10
  5. hatchet_sdk/clients/dispatcher/dispatcher.py +35 -8
  6. hatchet_sdk/clients/durable_event_listener.py +11 -12
  7. hatchet_sdk/clients/events.py +11 -15
  8. hatchet_sdk/clients/rest/models/tenant_resource.py +2 -0
  9. hatchet_sdk/clients/rest/models/workflow_runs_metrics.py +1 -5
  10. hatchet_sdk/clients/run_event_listener.py +55 -40
  11. hatchet_sdk/clients/v1/api_client.py +1 -38
  12. hatchet_sdk/clients/workflow_listener.py +9 -10
  13. hatchet_sdk/contracts/dispatcher_pb2.py +46 -46
  14. hatchet_sdk/contracts/dispatcher_pb2.pyi +4 -2
  15. hatchet_sdk/contracts/dispatcher_pb2_grpc.py +1 -1
  16. hatchet_sdk/contracts/events_pb2_grpc.py +1 -1
  17. hatchet_sdk/contracts/v1/dispatcher_pb2_grpc.py +1 -1
  18. hatchet_sdk/contracts/v1/workflows_pb2_grpc.py +1 -1
  19. hatchet_sdk/contracts/workflows_pb2_grpc.py +1 -1
  20. hatchet_sdk/features/cron.py +5 -4
  21. hatchet_sdk/features/logs.py +2 -1
  22. hatchet_sdk/features/metrics.py +4 -3
  23. hatchet_sdk/features/rate_limits.py +1 -1
  24. hatchet_sdk/features/runs.py +8 -7
  25. hatchet_sdk/features/scheduled.py +5 -4
  26. hatchet_sdk/features/workers.py +4 -3
  27. hatchet_sdk/features/workflows.py +4 -3
  28. hatchet_sdk/metadata.py +2 -2
  29. hatchet_sdk/runnables/standalone.py +3 -18
  30. hatchet_sdk/runnables/task.py +4 -0
  31. hatchet_sdk/runnables/workflow.py +28 -0
  32. hatchet_sdk/utils/aio.py +43 -0
  33. hatchet_sdk/worker/action_listener_process.py +7 -1
  34. hatchet_sdk/worker/runner/run_loop_manager.py +1 -1
  35. hatchet_sdk/worker/runner/runner.py +21 -5
  36. hatchet_sdk/workflow_run.py +7 -20
  37. hatchet_sdk-1.2.0.dist-info/METADATA +109 -0
  38. {hatchet_sdk-1.0.3.dist-info → hatchet_sdk-1.2.0.dist-info}/RECORD +40 -40
  39. hatchet_sdk/utils/aio_utils.py +0 -18
  40. hatchet_sdk-1.0.3.dist-info/METADATA +0 -42
  41. {hatchet_sdk-1.0.3.dist-info → hatchet_sdk-1.2.0.dist-info}/WHEEL +0 -0
  42. {hatchet_sdk-1.0.3.dist-info → hatchet_sdk-1.2.0.dist-info}/entry_points.txt +0 -0
hatchet_sdk/__init__.py CHANGED
@@ -138,7 +138,7 @@ from hatchet_sdk.contracts.workflows_pb2 import (
138
138
  )
139
139
  from hatchet_sdk.features.runs import BulkCancelReplayOpts, RunFilter
140
140
  from hatchet_sdk.hatchet import Hatchet
141
- from hatchet_sdk.runnables.task import Task
141
+ from hatchet_sdk.runnables.task import NonRetryableException, Task
142
142
  from hatchet_sdk.runnables.types import (
143
143
  ConcurrencyExpression,
144
144
  ConcurrencyLimitStrategy,
@@ -269,4 +269,5 @@ __all__ = [
269
269
  "BulkCancelReplayOpts",
270
270
  "RunFilter",
271
271
  "V1TaskStatus",
272
+ "NonRetryableException",
272
273
  ]
hatchet_sdk/client.py CHANGED
@@ -1,14 +1,9 @@
1
- import asyncio
2
-
3
- import grpc
4
-
5
1
  from hatchet_sdk.clients.admin import AdminClient
6
2
  from hatchet_sdk.clients.dispatcher.dispatcher import DispatcherClient
7
- from hatchet_sdk.clients.events import EventClient, new_event
3
+ from hatchet_sdk.clients.events import EventClient
8
4
  from hatchet_sdk.clients.run_event_listener import RunEventListenerClient
9
5
  from hatchet_sdk.clients.workflow_listener import PooledWorkflowRunListener
10
6
  from hatchet_sdk.config import ClientConfig
11
- from hatchet_sdk.connection import new_conn
12
7
  from hatchet_sdk.features.cron import CronClient
13
8
  from hatchet_sdk.features.logs import LogsClient
14
9
  from hatchet_sdk.features.metrics import MetricsClient
@@ -29,21 +24,13 @@ class Client:
29
24
  workflow_listener: PooledWorkflowRunListener | None | None = None,
30
25
  debug: bool = False,
31
26
  ):
32
- try:
33
- loop = asyncio.get_running_loop()
34
- except RuntimeError:
35
- loop = asyncio.new_event_loop()
36
- asyncio.set_event_loop(loop)
37
-
38
- conn: grpc.Channel = new_conn(config, False)
39
-
40
27
  self.config = config
41
28
  self.admin = admin_client or AdminClient(config)
42
29
  self.dispatcher = dispatcher_client or DispatcherClient(config)
43
- self.event = event_client or new_event(conn, config)
30
+ self.event = event_client or EventClient(config)
44
31
  self.listener = RunEventListenerClient(config)
45
32
  self.workflow_listener = workflow_listener
46
- self.logInterceptor = config.logger
33
+ self.log_interceptor = config.logger
47
34
  self.debug = debug
48
35
 
49
36
  self.cron = CronClient(self.config)
hatchet_sdk/clients/admin.py CHANGED
@@ -8,8 +8,6 @@ from google.protobuf import timestamp_pb2
8
8
  from pydantic import BaseModel, ConfigDict, Field, field_validator
9
9
 
10
10
  from hatchet_sdk.clients.rest.tenacity_utils import tenacity_retry
11
- from hatchet_sdk.clients.run_event_listener import RunEventListenerClient
12
- from hatchet_sdk.clients.workflow_listener import PooledWorkflowRunListener
13
11
  from hatchet_sdk.config import ClientConfig
14
12
  from hatchet_sdk.connection import new_conn
15
13
  from hatchet_sdk.contracts import workflows_pb2 as v0_workflow_protos
@@ -64,14 +62,11 @@ class AdminClient:
64
62
  def __init__(self, config: ClientConfig):
65
63
  conn = new_conn(config, False)
66
64
  self.config = config
67
- self.client = AdminServiceStub(conn) # type: ignore[no-untyped-call]
68
- self.v0_client = WorkflowServiceStub(conn) # type: ignore[no-untyped-call]
65
+ self.client = AdminServiceStub(conn)
66
+ self.v0_client = WorkflowServiceStub(conn)
69
67
  self.token = config.token
70
- self.listener_client = RunEventListenerClient(config=config)
71
68
  self.namespace = config.namespace
72
69
 
73
- self.pooled_workflow_listener: PooledWorkflowRunListener | None = None
74
-
75
70
  class TriggerWorkflowRequest(BaseModel):
76
71
  model_config = ConfigDict(extra="ignore")
77
72
 
@@ -307,9 +302,6 @@ class AdminClient:
307
302
  ) -> WorkflowRunRef:
308
303
  request = self._create_workflow_run_request(workflow_name, input, options)
309
304
 
310
- if not self.pooled_workflow_listener:
311
- self.pooled_workflow_listener = PooledWorkflowRunListener(self.config)
312
-
313
305
  try:
314
306
  resp = cast(
315
307
  v0_workflow_protos.TriggerWorkflowResponse,
@@ -325,8 +317,7 @@ class AdminClient:
325
317
 
326
318
  return WorkflowRunRef(
327
319
  workflow_run_id=resp.workflow_run_id,
328
- workflow_listener=self.pooled_workflow_listener,
329
- workflow_run_event_listener=self.listener_client,
320
+ config=self.config,
330
321
  )
331
322
 
332
323
  ## IMPORTANT: Keep this method's signature in sync with the wrapper in the OTel instrumentor
@@ -343,9 +334,6 @@ class AdminClient:
343
334
  async with spawn_index_lock:
344
335
  request = self._create_workflow_run_request(workflow_name, input, options)
345
336
 
346
- if not self.pooled_workflow_listener:
347
- self.pooled_workflow_listener = PooledWorkflowRunListener(self.config)
348
-
349
337
  try:
350
338
  resp = cast(
351
339
  v0_workflow_protos.TriggerWorkflowResponse,
@@ -362,8 +350,7 @@ class AdminClient:
362
350
 
363
351
  return WorkflowRunRef(
364
352
  workflow_run_id=resp.workflow_run_id,
365
- workflow_listener=self.pooled_workflow_listener,
366
- workflow_run_event_listener=self.listener_client,
353
+ config=self.config,
367
354
  )
368
355
 
369
356
  ## IMPORTANT: Keep this method's signature in sync with the wrapper in the OTel instrumentor
@@ -372,9 +359,6 @@ class AdminClient:
372
359
  self,
373
360
  workflows: list[WorkflowRunTriggerConfig],
374
361
  ) -> list[WorkflowRunRef]:
375
- if not self.pooled_workflow_listener:
376
- self.pooled_workflow_listener = PooledWorkflowRunListener(self.config)
377
-
378
362
  bulk_request = v0_workflow_protos.BulkTriggerWorkflowRequest(
379
363
  workflows=[
380
364
  self._create_workflow_run_request(
@@ -395,8 +379,7 @@ class AdminClient:
395
379
  return [
396
380
  WorkflowRunRef(
397
381
  workflow_run_id=workflow_run_id,
398
- workflow_listener=self.pooled_workflow_listener,
399
- workflow_run_event_listener=self.listener_client,
382
+ config=self.config,
400
383
  )
401
384
  for workflow_run_id in resp.workflow_run_ids
402
385
  ]
@@ -409,9 +392,6 @@ class AdminClient:
409
392
  ## IMPORTANT: The `pooled_workflow_listener` must be created 1) lazily, and not at `init` time, and 2) on the
410
393
  ## main thread. If 1) is not followed, you'll get an error about something being attached to the wrong event
411
394
  ## loop. If 2) is not followed, you'll get an error about the event loop not being set up.
412
- if not self.pooled_workflow_listener:
413
- self.pooled_workflow_listener = PooledWorkflowRunListener(self.config)
414
-
415
395
  async with spawn_index_lock:
416
396
  bulk_request = v0_workflow_protos.BulkTriggerWorkflowRequest(
417
397
  workflows=[
@@ -433,18 +413,13 @@ class AdminClient:
433
413
  return [
434
414
  WorkflowRunRef(
435
415
  workflow_run_id=workflow_run_id,
436
- workflow_listener=self.pooled_workflow_listener,
437
- workflow_run_event_listener=self.listener_client,
416
+ config=self.config,
438
417
  )
439
418
  for workflow_run_id in resp.workflow_run_ids
440
419
  ]
441
420
 
442
421
  def get_workflow_run(self, workflow_run_id: str) -> WorkflowRunRef:
443
- if not self.pooled_workflow_listener:
444
- self.pooled_workflow_listener = PooledWorkflowRunListener(self.config)
445
-
446
422
  return WorkflowRunRef(
447
423
  workflow_run_id=workflow_run_id,
448
- workflow_listener=self.pooled_workflow_listener,
449
- workflow_run_event_listener=self.listener_client,
424
+ config=self.config,
450
425
  )
hatchet_sdk/clients/dispatcher/action_listener.py CHANGED
@@ -152,7 +152,7 @@ class ActionListener:
152
152
  self.config = config
153
153
  self.worker_id = worker_id
154
154
 
155
- self.aio_client = DispatcherStub(new_conn(self.config, True)) # type: ignore[no-untyped-call]
155
+ self.aio_client = DispatcherStub(new_conn(self.config, True))
156
156
  self.token = self.config.token
157
157
 
158
158
  self.retries = 0
@@ -232,14 +232,8 @@ class ActionListener:
232
232
  if self.heartbeat_task is not None:
233
233
  return
234
234
 
235
- try:
236
- loop = asyncio.get_event_loop()
237
- except RuntimeError as e:
238
- if str(e).startswith("There is no current event loop in thread"):
239
- loop = asyncio.new_event_loop()
240
- asyncio.set_event_loop(loop)
241
- else:
242
- raise e
235
+ loop = asyncio.get_event_loop()
236
+
243
237
  self.heartbeat_task = loop.create_task(self.heartbeat())
244
238
 
245
239
  def __aiter__(self) -> AsyncGenerator[Action | None, None]:
@@ -386,7 +380,7 @@ class ActionListener:
386
380
  f"action listener connection interrupted, retrying... ({self.retries}/{DEFAULT_ACTION_LISTENER_RETRY_COUNT})"
387
381
  )
388
382
 
389
- self.aio_client = DispatcherStub(new_conn(self.config, True)) # type: ignore[no-untyped-call]
383
+ self.aio_client = DispatcherStub(new_conn(self.config, True))
390
384
 
391
385
  if self.listen_strategy == "v2":
392
386
  # we should await for the listener to be established before
hatchet_sdk/clients/dispatcher/dispatcher.py CHANGED
@@ -34,20 +34,23 @@ DEFAULT_REGISTER_TIMEOUT = 30
34
34
 
35
35
 
36
36
  class DispatcherClient:
37
- config: ClientConfig
38
-
39
37
  def __init__(self, config: ClientConfig):
40
38
  conn = new_conn(config, False)
41
- self.client = DispatcherStub(conn) # type: ignore[no-untyped-call]
39
+ self.client = DispatcherStub(conn)
42
40
 
43
- aio_conn = new_conn(config, True)
44
- self.aio_client = DispatcherStub(aio_conn) # type: ignore[no-untyped-call]
45
41
  self.token = config.token
46
42
  self.config = config
47
43
 
44
+ ## IMPORTANT: This needs to be created lazily so we don't require
45
+ ## an event loop to instantiate the client.
46
+ self.aio_client: DispatcherStub | None = None
47
+
48
48
  async def get_action_listener(
49
49
  self, req: GetActionListenerRequest
50
50
  ) -> ActionListener:
51
+ if not self.aio_client:
52
+ aio_conn = new_conn(self.config, True)
53
+ self.aio_client = DispatcherStub(aio_conn)
51
54
 
52
55
  # Override labels with the preset labels
53
56
  preset_labels = self.config.worker_preset_labels
@@ -73,10 +76,16 @@ class DispatcherClient:
73
76
  return ActionListener(self.config, response.workerId)
74
77
 
75
78
  async def send_step_action_event(
76
- self, action: Action, event_type: StepActionEventType, payload: str
79
+ self,
80
+ action: Action,
81
+ event_type: StepActionEventType,
82
+ payload: str,
83
+ should_not_retry: bool,
77
84
  ) -> grpc.aio.UnaryUnaryCall[StepActionEvent, ActionEventResponse] | None:
78
85
  try:
79
- return await self._try_send_step_action_event(action, event_type, payload)
86
+ return await self._try_send_step_action_event(
87
+ action, event_type, payload, should_not_retry
88
+ )
80
89
  except Exception as e:
81
90
  # for step action events, send a failure event when we cannot send the completed event
82
91
  if (
@@ -87,14 +96,23 @@ class DispatcherClient:
87
96
  action,
88
97
  STEP_EVENT_TYPE_FAILED,
89
98
  "Failed to send finished event: " + str(e),
99
+ should_not_retry=True,
90
100
  )
91
101
 
92
102
  return None
93
103
 
94
104
  @tenacity_retry
95
105
  async def _try_send_step_action_event(
96
- self, action: Action, event_type: StepActionEventType, payload: str
106
+ self,
107
+ action: Action,
108
+ event_type: StepActionEventType,
109
+ payload: str,
110
+ should_not_retry: bool,
97
111
  ) -> grpc.aio.UnaryUnaryCall[StepActionEvent, ActionEventResponse]:
112
+ if not self.aio_client:
113
+ aio_conn = new_conn(self.config, True)
114
+ self.aio_client = DispatcherStub(aio_conn)
115
+
98
116
  event_timestamp = Timestamp()
99
117
  event_timestamp.GetCurrentTime()
100
118
 
@@ -109,6 +127,7 @@ class DispatcherClient:
109
127
  eventType=event_type,
110
128
  eventPayload=payload,
111
129
  retryCount=action.retry_count,
130
+ shouldNotRetry=should_not_retry,
112
131
  )
113
132
 
114
133
  return cast(
@@ -122,6 +141,10 @@ class DispatcherClient:
122
141
  async def send_group_key_action_event(
123
142
  self, action: Action, event_type: GroupKeyActionEventType, payload: str
124
143
  ) -> grpc.aio.UnaryUnaryCall[GroupKeyActionEvent, ActionEventResponse]:
144
+ if not self.aio_client:
145
+ aio_conn = new_conn(self.config, True)
146
+ self.aio_client = DispatcherStub(aio_conn)
147
+
125
148
  event_timestamp = Timestamp()
126
149
  event_timestamp.GetCurrentTime()
127
150
 
@@ -191,6 +214,10 @@ class DispatcherClient:
191
214
  worker_id: str | None,
192
215
  labels: dict[str, str | int],
193
216
  ) -> None:
217
+ if not self.aio_client:
218
+ aio_conn = new_conn(self.config, True)
219
+ self.aio_client = DispatcherStub(aio_conn)
220
+
194
221
  worker_labels = {}
195
222
 
196
223
  for key, value in labels.items():
hatchet_sdk/clients/durable_event_listener.py CHANGED
@@ -84,14 +84,6 @@ class RegisterDurableEventRequest(BaseModel):
84
84
 
85
85
  class DurableEventListener:
86
86
  def __init__(self, config: ClientConfig):
87
- try:
88
- asyncio.get_running_loop()
89
- except RuntimeError:
90
- loop = asyncio.new_event_loop()
91
- asyncio.set_event_loop(loop)
92
-
93
- conn = new_conn(config, True)
94
- self.client = V1DispatcherStub(conn) # type: ignore[no-untyped-call]
95
87
  self.token = config.token
96
88
  self.config = config
97
89
 
@@ -129,11 +121,14 @@ class DurableEventListener:
129
121
  self.interrupt.set()
130
122
 
131
123
  async def _init_producer(self) -> None:
124
+ conn = new_conn(self.config, True)
125
+ client = V1DispatcherStub(conn)
126
+
132
127
  try:
133
128
  if not self.listener:
134
129
  while True:
135
130
  try:
136
- self.listener = await self._retry_subscribe()
131
+ self.listener = await self._retry_subscribe(client)
137
132
 
138
133
  logger.debug("Workflow run listener connected.")
139
134
 
@@ -282,6 +277,7 @@ class DurableEventListener:
282
277
 
283
278
  async def _retry_subscribe(
284
279
  self,
280
+ client: V1DispatcherStub,
285
281
  ) -> grpc.aio.UnaryStreamCall[ListenForDurableEventRequest, DurableEvent]:
286
282
  retries = 0
287
283
 
@@ -298,8 +294,8 @@ class DurableEventListener:
298
294
  grpc.aio.UnaryStreamCall[
299
295
  ListenForDurableEventRequest, DurableEvent
300
296
  ],
301
- self.client.ListenForDurableEvent(
302
- self._request(),
297
+ client.ListenForDurableEvent(
298
+ self._request(), # type: ignore[arg-type]
303
299
  metadata=get_metadata(self.token),
304
300
  ),
305
301
  )
@@ -315,7 +311,10 @@ class DurableEventListener:
315
311
  def register_durable_event(
316
312
  self, request: RegisterDurableEventRequest
317
313
  ) -> Literal[True]:
318
- self.client.RegisterDurableEvent(
314
+ conn = new_conn(self.config, True)
315
+ client = V1DispatcherStub(conn)
316
+
317
+ client.RegisterDurableEvent(
319
318
  request.to_proto(),
320
319
  timeout=5,
321
320
  metadata=get_metadata(self.token),
hatchet_sdk/clients/events.py CHANGED
@@ -3,15 +3,16 @@ import datetime
3
3
  import json
4
4
  from typing import List, cast
5
5
 
6
- import grpc
7
6
  from google.protobuf import timestamp_pb2
8
7
  from pydantic import BaseModel, Field
9
8
 
10
9
  from hatchet_sdk.clients.rest.tenacity_utils import tenacity_retry
11
10
  from hatchet_sdk.config import ClientConfig
11
+ from hatchet_sdk.connection import new_conn
12
12
  from hatchet_sdk.contracts.events_pb2 import (
13
13
  BulkPushEventRequest,
14
14
  Event,
15
+ Events,
15
16
  PushEventRequest,
16
17
  PutLogRequest,
17
18
  PutStreamEventRequest,
@@ -21,13 +22,6 @@ from hatchet_sdk.metadata import get_metadata
21
22
  from hatchet_sdk.utils.typing import JSONSerializableMapping
22
23
 
23
24
 
24
- def new_event(conn: grpc.Channel, config: ClientConfig) -> "EventClient":
25
- return EventClient(
26
- client=EventsServiceStub(conn), # type: ignore[no-untyped-call]
27
- config=config,
28
- )
29
-
30
-
31
25
  def proto_timestamp_now() -> timestamp_pb2.Timestamp:
32
26
  t = datetime.datetime.now().timestamp()
33
27
  seconds = int(t)
@@ -52,8 +46,10 @@ class BulkPushEventWithMetadata(BaseModel):
52
46
 
53
47
 
54
48
  class EventClient:
55
- def __init__(self, client: EventsServiceStub, config: ClientConfig):
56
- self.client = client
49
+ def __init__(self, config: ClientConfig):
50
+ conn = new_conn(config, False)
51
+ self.client = EventsServiceStub(conn)
52
+
57
53
  self.token = config.token
58
54
  self.namespace = config.namespace
59
55
 
@@ -146,11 +142,11 @@ class EventClient:
146
142
  ]
147
143
  )
148
144
 
149
- response = self.client.BulkPush(bulk_request, metadata=get_metadata(self.token))
150
-
151
- return cast(
152
- list[Event],
153
- response.events,
145
+ return list(
146
+ cast(
147
+ Events,
148
+ self.client.BulkPush(bulk_request, metadata=get_metadata(self.token)),
149
+ ).events
154
150
  )
155
151
 
156
152
  def log(self, message: str, step_run_id: str) -> None:
hatchet_sdk/clients/rest/models/tenant_resource.py CHANGED
@@ -29,8 +29,10 @@ class TenantResource(str, Enum):
29
29
  allowed enum values
30
30
  """
31
31
  WORKER = "WORKER"
32
+ WORKER_SLOT = "WORKER_SLOT"
32
33
  EVENT = "EVENT"
33
34
  WORKFLOW_RUN = "WORKFLOW_RUN"
35
+ TASK_RUN = "TASK_RUN"
34
36
  CRON = "CRON"
35
37
  SCHEDULE = "SCHEDULE"
36
38
 
hatchet_sdk/clients/rest/models/workflow_runs_metrics.py CHANGED
@@ -22,17 +22,13 @@ from typing import Any, ClassVar, Dict, List, Optional, Set
22
22
  from pydantic import BaseModel, ConfigDict
23
23
  from typing_extensions import Self
24
24
 
25
- from hatchet_sdk.clients.rest.models.workflow_runs_metrics_counts import (
26
- WorkflowRunsMetricsCounts,
27
- )
28
-
29
25
 
30
26
  class WorkflowRunsMetrics(BaseModel):
31
27
  """
32
28
  WorkflowRunsMetrics
33
29
  """ # noqa: E501
34
30
 
35
- counts: Optional[WorkflowRunsMetricsCounts] = None
31
+ counts: Optional[Dict[str, Any]] = None
36
32
  __properties: ClassVar[List[str]] = ["counts"]
37
33
 
38
34
  model_config = ConfigDict(
hatchet_sdk/clients/run_event_listener.py CHANGED
@@ -1,6 +1,8 @@
1
1
  import asyncio
2
2
  from enum import Enum
3
- from typing import Any, AsyncGenerator, Callable, Generator, cast
3
+ from queue import Empty, Queue
4
+ from threading import Thread
5
+ from typing import Any, AsyncGenerator, Callable, Generator, Literal, TypeVar, cast
4
6
 
5
7
  import grpc
6
8
  from pydantic import BaseModel
@@ -55,6 +57,8 @@ workflow_run_event_type_mapping = {
55
57
  ResourceEventType.RESOURCE_EVENT_TYPE_TIMED_OUT: WorkflowRunEventType.WORKFLOW_RUN_EVENT_TYPE_TIMED_OUT,
56
58
  }
57
59
 
60
+ T = TypeVar("T")
61
+
58
62
 
59
63
  class StepRunEvent(BaseModel):
60
64
  type: StepRunEventType
@@ -64,18 +68,20 @@ class StepRunEvent(BaseModel):
64
68
  class RunEventListener:
65
69
  def __init__(
66
70
  self,
67
- client: DispatcherStub,
68
- token: str,
71
+ config: ClientConfig,
69
72
  workflow_run_id: str | None = None,
70
73
  additional_meta_kv: tuple[str, str] | None = None,
71
74
  ):
72
- self.client = client
75
+ self.config = config
73
76
  self.stop_signal = False
74
- self.token = token
75
77
 
76
78
  self.workflow_run_id = workflow_run_id
77
79
  self.additional_meta_kv = additional_meta_kv
78
80
 
81
+ ## IMPORTANT: This needs to be created lazily so we don't require
82
+ ## an event loop to instantiate the client.
83
+ self.client: DispatcherStub | None = None
84
+
79
85
  def abort(self) -> None:
80
86
  self.stop_signal = True
81
87
 
@@ -85,27 +91,46 @@ class RunEventListener:
85
91
  async def __anext__(self) -> StepRunEvent:
86
92
  return await self._generator().__anext__()
87
93
 
88
- def __iter__(self) -> Generator[StepRunEvent, None, None]:
89
- try:
90
- loop = asyncio.get_event_loop()
91
- except RuntimeError as e:
92
- if str(e).startswith("There is no current event loop in thread"):
93
- loop = asyncio.new_event_loop()
94
- asyncio.set_event_loop(loop)
95
- else:
96
- raise e
94
+ def async_to_sync_thread(
95
+ self, async_iter: AsyncGenerator[T, None]
96
+ ) -> Generator[T, None, None]:
97
+ q = Queue[T | Literal["DONE"]]()
98
+ done_sentinel: Literal["DONE"] = "DONE"
99
+
100
+ def runner() -> None:
101
+ loop = asyncio.new_event_loop()
102
+ asyncio.set_event_loop(loop)
97
103
 
98
- async_iter = self.__aiter__()
104
+ async def consume() -> None:
105
+ try:
106
+ async for item in async_iter:
107
+ q.put(item)
108
+ finally:
109
+ q.put(done_sentinel)
110
+
111
+ try:
112
+ loop.run_until_complete(consume())
113
+ finally:
114
+ loop.stop()
115
+ loop.close()
116
+
117
+ thread = Thread(target=runner)
118
+ thread.start()
99
119
 
100
120
  while True:
101
121
  try:
102
- future = asyncio.ensure_future(async_iter.__anext__())
103
- yield loop.run_until_complete(future)
104
- except StopAsyncIteration:
105
- break
106
- except Exception as e:
107
- print(f"Error in synchronous iterator: {e}")
108
- break
122
+ item = q.get(timeout=1)
123
+ if item == "DONE":
124
+ break
125
+ yield item
126
+ except Empty:
127
+ continue
128
+
129
+ thread.join()
130
+
131
+ def __iter__(self) -> Generator[StepRunEvent, None, None]:
132
+ for item in self.async_to_sync_thread(self.__aiter__()):
133
+ yield item
109
134
 
110
135
  async def _generator(self) -> AsyncGenerator[StepRunEvent, None]:
111
136
  while True:
@@ -172,6 +197,10 @@ class RunEventListener:
172
197
  async def retry_subscribe(self) -> AsyncGenerator[WorkflowEvent, None]:
173
198
  retries = 0
174
199
 
200
+ if self.client is None:
201
+ aio_conn = new_conn(self.config, True)
202
+ self.client = DispatcherStub(aio_conn)
203
+
175
204
  while retries < DEFAULT_ACTION_LISTENER_RETRY_COUNT:
176
205
  try:
177
206
  if retries > 0:
@@ -184,7 +213,7 @@ class RunEventListener:
184
213
  SubscribeToWorkflowEventsRequest(
185
214
  workflowRunId=self.workflow_run_id,
186
215
  ),
187
- metadata=get_metadata(self.token),
216
+ metadata=get_metadata(self.config.token),
188
217
  ),
189
218
  )
190
219
  elif self.additional_meta_kv is not None:
@@ -195,7 +224,7 @@ class RunEventListener:
195
224
  additionalMetaKey=self.additional_meta_kv[0],
196
225
  additionalMetaValue=self.additional_meta_kv[1],
197
226
  ),
198
- metadata=get_metadata(self.token),
227
+ metadata=get_metadata(self.config.token),
199
228
  ),
200
229
  )
201
230
  else:
@@ -212,30 +241,16 @@ class RunEventListener:
212
241
 
213
242
  class RunEventListenerClient:
214
243
  def __init__(self, config: ClientConfig):
215
- self.token = config.token
216
244
  self.config = config
217
- self.client: DispatcherStub | None = None
218
245
 
219
246
  def stream_by_run_id(self, workflow_run_id: str) -> RunEventListener:
220
247
  return self.stream(workflow_run_id)
221
248
 
222
249
  def stream(self, workflow_run_id: str) -> RunEventListener:
223
- if not self.client:
224
- aio_conn = new_conn(self.config, True)
225
- self.client = DispatcherStub(aio_conn) # type: ignore[no-untyped-call]
226
-
227
- return RunEventListener(
228
- client=self.client, token=self.token, workflow_run_id=workflow_run_id
229
- )
250
+ return RunEventListener(config=self.config, workflow_run_id=workflow_run_id)
230
251
 
231
252
  def stream_by_additional_metadata(self, key: str, value: str) -> RunEventListener:
232
- if not self.client:
233
- aio_conn = new_conn(self.config, True)
234
- self.client = DispatcherStub(aio_conn) # type: ignore[no-untyped-call]
235
-
236
- return RunEventListener(
237
- client=self.client, token=self.token, additional_meta_kv=(key, value)
238
- )
253
+ return RunEventListener(config=self.config, additional_meta_kv=(key, value))
239
254
 
240
255
  async def on(
241
256
  self, workflow_run_id: str, handler: Callable[[StepRunEvent], Any] | None = None
hatchet_sdk/clients/v1/api_client.py CHANGED
@@ -1,6 +1,4 @@
1
- import asyncio
2
- from concurrent.futures import ThreadPoolExecutor
3
- from typing import AsyncContextManager, Callable, Coroutine, ParamSpec, TypeVar
1
+ from typing import AsyncContextManager, ParamSpec, TypeVar
4
2
 
5
3
  from hatchet_sdk.clients.rest.api_client import ApiClient
6
4
  from hatchet_sdk.clients.rest.configuration import Configuration
@@ -44,38 +42,3 @@ class BaseRestClient:
44
42
 
45
43
  def client(self) -> AsyncContextManager[ApiClient]:
46
44
  return ApiClient(self.api_config)
47
-
48
- def _run_async_function_do_not_use_directly(
49
- self,
50
- async_func: Callable[P, Coroutine[Y, S, R]],
51
- *args: P.args,
52
- **kwargs: P.kwargs,
53
- ) -> R:
54
- loop = asyncio.new_event_loop()
55
- asyncio.set_event_loop(loop)
56
- try:
57
- return loop.run_until_complete(async_func(*args, **kwargs))
58
- finally:
59
- loop.close()
60
-
61
- def _run_async_from_sync(
62
- self,
63
- async_func: Callable[P, Coroutine[Y, S, R]],
64
- *args: P.args,
65
- **kwargs: P.kwargs,
66
- ) -> R:
67
- try:
68
- loop = asyncio.get_event_loop()
69
- except RuntimeError:
70
- loop = None
71
-
72
- if loop and loop.is_running():
73
- return loop.run_until_complete(async_func(*args, **kwargs))
74
- else:
75
- with ThreadPoolExecutor() as executor:
76
- future = executor.submit(
77
- lambda: self._run_async_function_do_not_use_directly(
78
- async_func, *args, **kwargs
79
- )
80
- )
81
- return future.result()