hatchet-sdk 1.12.3__py3-none-any.whl → 1.13.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.

Files changed (48)
  1. hatchet_sdk/__init__.py +46 -40
  2. hatchet_sdk/clients/admin.py +18 -23
  3. hatchet_sdk/clients/dispatcher/action_listener.py +4 -3
  4. hatchet_sdk/clients/dispatcher/dispatcher.py +1 -4
  5. hatchet_sdk/clients/event_ts.py +2 -1
  6. hatchet_sdk/clients/events.py +16 -12
  7. hatchet_sdk/clients/listeners/durable_event_listener.py +4 -2
  8. hatchet_sdk/clients/listeners/pooled_listener.py +2 -2
  9. hatchet_sdk/clients/listeners/run_event_listener.py +7 -8
  10. hatchet_sdk/clients/listeners/workflow_listener.py +14 -6
  11. hatchet_sdk/clients/rest/api_response.py +3 -2
  12. hatchet_sdk/clients/rest/tenacity_utils.py +6 -8
  13. hatchet_sdk/config.py +2 -0
  14. hatchet_sdk/connection.py +10 -4
  15. hatchet_sdk/context/context.py +170 -46
  16. hatchet_sdk/context/worker_context.py +4 -7
  17. hatchet_sdk/contracts/dispatcher_pb2.py +38 -38
  18. hatchet_sdk/contracts/dispatcher_pb2.pyi +4 -2
  19. hatchet_sdk/contracts/events_pb2.py +13 -13
  20. hatchet_sdk/contracts/events_pb2.pyi +4 -2
  21. hatchet_sdk/contracts/v1/workflows_pb2.py +1 -1
  22. hatchet_sdk/contracts/v1/workflows_pb2.pyi +2 -2
  23. hatchet_sdk/exceptions.py +99 -1
  24. hatchet_sdk/features/cron.py +2 -2
  25. hatchet_sdk/features/filters.py +3 -3
  26. hatchet_sdk/features/runs.py +4 -4
  27. hatchet_sdk/features/scheduled.py +8 -9
  28. hatchet_sdk/hatchet.py +65 -64
  29. hatchet_sdk/opentelemetry/instrumentor.py +20 -20
  30. hatchet_sdk/runnables/action.py +1 -2
  31. hatchet_sdk/runnables/contextvars.py +19 -0
  32. hatchet_sdk/runnables/task.py +37 -29
  33. hatchet_sdk/runnables/types.py +9 -8
  34. hatchet_sdk/runnables/workflow.py +57 -42
  35. hatchet_sdk/utils/proto_enums.py +4 -4
  36. hatchet_sdk/utils/timedelta_to_expression.py +2 -3
  37. hatchet_sdk/utils/typing.py +11 -17
  38. hatchet_sdk/waits.py +6 -5
  39. hatchet_sdk/worker/action_listener_process.py +33 -13
  40. hatchet_sdk/worker/runner/run_loop_manager.py +15 -11
  41. hatchet_sdk/worker/runner/runner.py +102 -92
  42. hatchet_sdk/worker/runner/utils/capture_logs.py +72 -31
  43. hatchet_sdk/worker/worker.py +29 -25
  44. hatchet_sdk/workflow_run.py +4 -2
  45. {hatchet_sdk-1.12.3.dist-info → hatchet_sdk-1.13.0.dist-info}/METADATA +1 -1
  46. {hatchet_sdk-1.12.3.dist-info → hatchet_sdk-1.13.0.dist-info}/RECORD +48 -48
  47. {hatchet_sdk-1.12.3.dist-info → hatchet_sdk-1.13.0.dist-info}/WHEEL +0 -0
  48. {hatchet_sdk-1.12.3.dist-info → hatchet_sdk-1.13.0.dist-info}/entry_points.txt +0 -0
hatchet_sdk/utils/typing.py CHANGED
@@ -1,20 +1,11 @@
 import sys
-from typing import (
-    Any,
-    Awaitable,
-    Coroutine,
-    Generator,
-    Mapping,
-    Type,
-    TypeAlias,
-    TypeGuard,
-    TypeVar,
-)
+from collections.abc import Awaitable, Coroutine, Generator
+from typing import Any, Literal, TypeAlias, TypeGuard, TypeVar
 
 from pydantic import BaseModel
 
 
-def is_basemodel_subclass(model: Any) -> TypeGuard[Type[BaseModel]]:
+def is_basemodel_subclass(model: Any) -> TypeGuard[type[BaseModel]]:
     try:
         return issubclass(model, BaseModel)
     except TypeError:
@@ -22,18 +13,21 @@ def is_basemodel_subclass(model: Any) -> TypeGuard[Type[BaseModel]]:
 
 
 class TaskIOValidator(BaseModel):
-    workflow_input: Type[BaseModel] | None = None
-    step_output: Type[BaseModel] | None = None
+    workflow_input: type[BaseModel] | None = None
+    step_output: type[BaseModel] | None = None
 
 
-JSONSerializableMapping = Mapping[str, Any]
+JSONSerializableMapping = dict[str, Any]
 
 
 _T_co = TypeVar("_T_co", covariant=True)
 
 if sys.version_info >= (3, 12):
-    AwaitableLike: TypeAlias = Awaitable[_T_co]  # noqa: Y047
-    CoroutineLike: TypeAlias = Coroutine[Any, Any, _T_co]  # noqa: Y047
+    AwaitableLike: TypeAlias = Awaitable[_T_co]
+    CoroutineLike: TypeAlias = Coroutine[Any, Any, _T_co]
 else:
     AwaitableLike: TypeAlias = Generator[Any, None, _T_co] | Awaitable[_T_co]
     CoroutineLike: TypeAlias = Generator[Any, None, _T_co] | Coroutine[Any, Any, _T_co]
+
+STOP_LOOP_TYPE = Literal["STOP_LOOP"]
+STOP_LOOP: STOP_LOOP_TYPE = "STOP_LOOP"  # Sentinel object to stop the loop
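The STOP_LOOP sentinel, previously duplicated in action_listener_process.py and run_loop_manager.py (removed there later in this diff), now has a single home here. Typing the sentinel as a Literal lets queue consumers narrow the element type after an equality check. A minimal sketch of the pattern; the queue and consumer below are illustrative, not SDK code:

```python
from queue import Queue
from typing import Literal

STOP_LOOP_TYPE = Literal["STOP_LOOP"]
STOP_LOOP: STOP_LOOP_TYPE = "STOP_LOOP"  # sentinel pushed by producers to end the loop

def consume(q: "Queue[int | STOP_LOOP_TYPE]") -> None:
    while True:
        item = q.get()
        if item == STOP_LOOP:
            break  # checkers like pyright narrow `item` to int past this check
        print(f"processing {item}")

q: "Queue[int | STOP_LOOP_TYPE]" = Queue()
q.put(1)
q.put(STOP_LOOP)
consume(q)
```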
hatchet_sdk/waits.py CHANGED
@@ -6,6 +6,7 @@ from uuid import uuid4
 
 from pydantic import BaseModel, Field
 
+from hatchet_sdk.config import ClientConfig
 from hatchet_sdk.contracts.v1.shared.condition_pb2 import Action as ProtoAction
 from hatchet_sdk.contracts.v1.shared.condition_pb2 import (
     BaseMatchCondition,
@@ -53,7 +54,7 @@ class Condition(ABC):
 
     @abstractmethod
     def to_proto(
-        self,
+        self, config: ClientConfig
     ) -> UserEventMatchCondition | ParentOverrideMatchCondition | SleepMatchCondition:
         pass
 
@@ -71,7 +72,7 @@ class SleepCondition(Condition):
 
         self.duration = duration
 
-    def to_proto(self) -> SleepMatchCondition:
+    def to_proto(self, config: ClientConfig) -> SleepMatchCondition:
         return SleepMatchCondition(
             base=self.base.to_proto(),
             sleep_for=timedelta_to_expr(self.duration),
@@ -95,10 +96,10 @@ class UserEventCondition(Condition):
         self.event_key = event_key
         self.expression = expression
 
-    def to_proto(self) -> UserEventMatchCondition:
+    def to_proto(self, config: ClientConfig) -> UserEventMatchCondition:
         return UserEventMatchCondition(
             base=self.base.to_proto(),
-            user_event_key=self.event_key,
+            user_event_key=config.apply_namespace(self.event_key),
         )
 
 
@@ -124,7 +125,7 @@ class ParentCondition(Condition):
 
         self.parent = parent
 
-    def to_proto(self) -> ParentOverrideMatchCondition:
+    def to_proto(self, config: ClientConfig) -> ParentOverrideMatchCondition:
         return ParentOverrideMatchCondition(
             base=self.base.to_proto(),
             parent_readable_id=self.parent.name,
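Condition.to_proto now takes the ClientConfig so that UserEventCondition can apply the tenant namespace to event keys, matching how published events are keyed. The implementation of apply_namespace is not part of this diff; a hedged sketch of what such a helper typically does, with the class stub and prefixing rule assumed:

```python
class ClientConfig:
    # Illustrative stand-in for hatchet_sdk.config.ClientConfig; only the
    # namespacing behavior is sketched, and the exact rule is an assumption.
    def __init__(self, namespace: str = "") -> None:
        self.namespace = namespace

    def apply_namespace(self, key: str) -> str:
        # Prefix the key with the namespace unless it is already prefixed.
        if not self.namespace or key.startswith(self.namespace):
            return key
        return f"{self.namespace}{key}"

print(ClientConfig(namespace="dev_").apply_namespace("order:created"))
# dev_order:created
```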
hatchet_sdk/worker/action_listener_process.py CHANGED
@@ -4,9 +4,10 @@ import signal
 import time
 from dataclasses import dataclass
 from multiprocessing import Queue
-from typing import Any, Literal
+from typing import Any
 
 import grpc
+from grpc.aio import UnaryUnaryCall
 
 from hatchet_sdk.client import Client
 from hatchet_sdk.clients.dispatcher.action_listener import (
@@ -19,6 +20,9 @@ from hatchet_sdk.config import ClientConfig
 from hatchet_sdk.contracts.dispatcher_pb2 import (
     GROUP_KEY_EVENT_TYPE_STARTED,
     STEP_EVENT_TYPE_STARTED,
+    ActionEventResponse,
+    GroupKeyActionEvent,
+    StepActionEvent,
 )
 from hatchet_sdk.logger import logger
 from hatchet_sdk.runnables.action import Action, ActionType
@@ -29,6 +33,7 @@ from hatchet_sdk.runnables.contextvars import (
     ctx_workflow_run_id,
 )
 from hatchet_sdk.utils.backoff import exp_backoff_sleep
+from hatchet_sdk.utils.typing import STOP_LOOP, STOP_LOOP_TYPE
 
 ACTION_EVENT_RETRY_COUNT = 5
 
@@ -41,9 +46,6 @@ class ActionEvent:
     should_not_retry: bool
 
 
-STOP_LOOP_TYPE = Literal["STOP_LOOP"]
-STOP_LOOP: STOP_LOOP_TYPE = "STOP_LOOP"  # Sentinel object to stop the loop
-
 BLOCKED_THREAD_WARNING = "THE TIME TO START THE TASK RUN IS TOO LONG, THE EVENT LOOP MAY BE BLOCKED. See https://docs.hatchet.run/blog/warning-event-loop-blocked for details and debugging help."
 
 
@@ -56,9 +58,9 @@ class WorkerActionListenerProcess:
         config: ClientConfig,
         action_queue: "Queue[Action]",
         event_queue: "Queue[ActionEvent | STOP_LOOP_TYPE]",
-        handle_kill: bool = True,
-        debug: bool = False,
-        labels: dict[str, str | int] = {},
+        handle_kill: bool,
+        debug: bool,
+        labels: dict[str, str | int],
     ) -> None:
         self.name = name
         self.actions = actions
@@ -75,6 +77,14 @@
         self.action_loop_task: asyncio.Task[None] | None = None
         self.event_send_loop_task: asyncio.Task[None] | None = None
         self.running_step_runs: dict[str, float] = {}
+        self.step_action_events: set[
+            asyncio.Task[UnaryUnaryCall[StepActionEvent, ActionEventResponse] | None]
+        ] = set()
+        self.group_key_action_events: set[
+            asyncio.Task[
+                UnaryUnaryCall[GroupKeyActionEvent, ActionEventResponse] | None
+            ]
+        ] = set()
 
         if self.debug:
             logger.setLevel(logging.DEBUG)
@@ -144,20 +154,21 @@
                 break
 
             logger.debug(f"tx: event: {event.action.action_id}/{event.type}")
-            asyncio.create_task(self.send_event(event))
+            t = asyncio.create_task(self.send_event(event))
+            self.step_action_events.add(t)
+            t.add_done_callback(lambda t: self.step_action_events.discard(t))
 
     async def start_blocked_main_loop(self) -> None:
         threshold = 1
         while not self.killing:
             count = 0
-            for _, start_time in self.running_step_runs.items():
+            for start_time in self.running_step_runs.values():
                 diff = self.now() - start_time
                 if diff > threshold:
                     count += 1
 
             if count > 0:
                 logger.warning(f"{BLOCKED_THREAD_WARNING}: Waiting Steps {count}")
-                print(asyncio.current_task())
             await asyncio.sleep(1)
 
     async def send_event(self, event: ActionEvent, retry_attempt: int = 1) -> None:
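start_blocked_main_loop is the watchdog behind BLOCKED_THREAD_WARNING above: it compares wall-clock time against the recorded start times of pending step runs, and the stray print(asyncio.current_task()) debug line is now removed. A self-contained sketch of the underlying idea, detecting a blocked loop by measuring how late asyncio.sleep resumes (illustrative only, not the SDK's exact code):

```python
import asyncio
import time

async def watchdog(interval: float = 1.0, tolerance: float = 0.5) -> None:
    while True:
        start = time.monotonic()
        await asyncio.sleep(interval)
        # If sync work blocked the loop thread, sleep() resumes late.
        lag = time.monotonic() - start - interval
        if lag > tolerance:
            print(f"event loop was blocked for ~{lag:.2f}s")

async def main() -> None:
    wd = asyncio.create_task(watchdog())
    await asyncio.sleep(0.1)
    time.sleep(2)  # simulate blocking synchronous work on the loop thread
    await asyncio.sleep(1.5)  # let the watchdog wake up and report
    wd.cancel()

asyncio.run(main())
```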
@@ -187,7 +198,7 @@
                             self.now()
                         )
 
-                    asyncio.create_task(
+                    send_started_event_task = asyncio.create_task(
                         self.dispatcher_client.send_step_action_event(
                             event.action,
                             event.type,
@@ -195,14 +206,23 @@
                             event.should_not_retry,
                         )
                     )
+
+                    self.step_action_events.add(send_started_event_task)
+                    send_started_event_task.add_done_callback(
+                        lambda t: self.step_action_events.discard(t)
+                    )
                 case ActionType.CANCEL_STEP_RUN:
                     logger.debug("unimplemented event send")
                 case ActionType.START_GET_GROUP_KEY:
-                    asyncio.create_task(
+                    get_group_key_task = asyncio.create_task(
                         self.dispatcher_client.send_group_key_action_event(
                             event.action, event.type, event.payload
                         )
                     )
+                    self.group_key_action_events.add(get_group_key_task)
+                    get_group_key_task.add_done_callback(
+                        lambda t: self.group_key_action_events.discard(t)
+                    )
                 case _:
                     logger.error("unknown action type for event send")
         except Exception as e:
@@ -317,7 +337,7 @@ def worker_action_listener_process(*args: Any, **kwargs: Any) -> None:
         process = WorkerActionListenerProcess(*args, **kwargs)
         await process.start()
         # Keep the process running
-        while not process.killing:
+        while not process.killing:  # noqa: ASYNC110
            await asyncio.sleep(0.1)
 
     asyncio.run(run())
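The recurring change in this file is that every asyncio.create_task result is now stored in a set and discarded on completion. The event loop holds only weak references to tasks, so a fire-and-forget task can be garbage-collected before it finishes; keeping a strong reference until the done-callback fires is the pattern the asyncio documentation recommends. A runnable distillation:

```python
import asyncio

background_tasks: set[asyncio.Task[None]] = set()

async def send(msg: str) -> None:
    await asyncio.sleep(0.1)
    print(f"sent {msg}")

async def main() -> None:
    for i in range(3):
        t = asyncio.create_task(send(f"event-{i}"))
        background_tasks.add(t)  # strong reference keeps the task alive
        t.add_done_callback(background_tasks.discard)  # drop it once done
    await asyncio.gather(*background_tasks)

asyncio.run(main())
```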
hatchet_sdk/worker/runner/run_loop_manager.py CHANGED
@@ -1,19 +1,17 @@
 import asyncio
 import logging
 from multiprocessing import Queue
-from typing import Any, Literal, TypeVar
+from typing import Any, TypeVar
 
 from hatchet_sdk.client import Client
 from hatchet_sdk.config import ClientConfig
 from hatchet_sdk.logger import logger
 from hatchet_sdk.runnables.action import Action
 from hatchet_sdk.runnables.task import Task
+from hatchet_sdk.utils.typing import STOP_LOOP, STOP_LOOP_TYPE
 from hatchet_sdk.worker.action_listener_process import ActionEvent
 from hatchet_sdk.worker.runner.runner import Runner
-from hatchet_sdk.worker.runner.utils.capture_logs import capture_logs
-
-STOP_LOOP_TYPE = Literal["STOP_LOOP"]
-STOP_LOOP: STOP_LOOP_TYPE = "STOP_LOOP"
+from hatchet_sdk.worker.runner.utils.capture_logs import AsyncLogSender, capture_logs
 
 T = TypeVar("T")
 
@@ -28,10 +26,10 @@ class WorkerActionRunLoopManager:
         action_queue: "Queue[Action | STOP_LOOP_TYPE]",
         event_queue: "Queue[ActionEvent]",
         loop: asyncio.AbstractEventLoop,
-        handle_kill: bool = True,
-        debug: bool = False,
-        labels: dict[str, str | int] = {},
-        lifespan_context: Any | None = None,
+        handle_kill: bool,
+        debug: bool,
+        labels: dict[str, str | int] | None,
+        lifespan_context: Any | None,
     ) -> None:
         self.name = name
         self.action_registry = action_registry
@@ -52,15 +50,19 @@
         self.runner: Runner | None = None
 
         self.client = Client(config=self.config, debug=self.debug)
+        self.start_loop_manager_task: asyncio.Task[None] | None = None
+        self.log_sender = AsyncLogSender(self.client.event)
+        self.log_task = self.loop.create_task(self.log_sender.consume())
+
         self.start()
 
     def start(self) -> None:
-        k = self.loop.create_task(self.aio_start())  # noqa: F841
+        self.start_loop_manager_task = self.loop.create_task(self.aio_start())
 
     async def aio_start(self, retry_count: int = 1) -> None:
         await capture_logs(
             self.client.log_interceptor,
-            self.client.event,
+            self.log_sender,
             self._async_start,
         )()
 
@@ -75,6 +77,7 @@
         self.killing = True
 
         self.action_queue.put(STOP_LOOP)
+        self.log_sender.publish(STOP_LOOP)
 
     async def wait_for_tasks(self) -> None:
         if self.runner:
@@ -89,6 +92,7 @@
             self.action_registry,
             self.labels,
             self.lifespan_context,
+            self.log_sender,
         )
 
         logger.debug(f"'{self.name}' waiting for {list(self.action_registry.keys())}")
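As in the listener process, the `= {}` and related defaults are gone from this constructor: a mutable default is evaluated once at function definition time and shared by every call that omits the argument, so making the parameters required (or `None`, as with labels here) sidesteps the pitfall. The classic failure mode, shown with hypothetical helpers:

```python
def register_bad(worker: str, labels: dict[str, str] = {}) -> dict[str, str]:
    labels[worker] = "registered"  # mutates the single shared default dict
    return labels

print(register_bad("a"))  # {'a': 'registered'}
print(register_bad("b"))  # {'a': 'registered', 'b': 'registered'}  <- leaked state

def register_good(worker: str, labels: dict[str, str] | None = None) -> dict[str, str]:
    labels = labels if labels is not None else {}  # fresh dict per call
    labels[worker] = "registered"
    return labels
```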
hatchet_sdk/worker/runner/runner.py CHANGED
@@ -1,14 +1,14 @@
 import asyncio
-import contextvars
 import ctypes
 import functools
 import json
-import traceback
+from collections.abc import Callable
 from concurrent.futures import ThreadPoolExecutor
+from contextlib import suppress
 from enum import Enum
 from multiprocessing import Queue
 from threading import Thread, current_thread
-from typing import Any, Callable, Dict, Literal, cast, overload
+from typing import Any, Literal, cast, overload
 
 from pydantic import BaseModel
 
@@ -30,7 +30,7 @@ from hatchet_sdk.contracts.dispatcher_pb2 import (
     STEP_EVENT_TYPE_FAILED,
     STEP_EVENT_TYPE_STARTED,
 )
-from hatchet_sdk.exceptions import NonRetryableException
+from hatchet_sdk.exceptions import NonRetryableException, TaskRunError
 from hatchet_sdk.features.runs import RunsClient
 from hatchet_sdk.logger import logger
 from hatchet_sdk.runnables.action import Action, ActionKey, ActionType
@@ -40,12 +40,17 @@ from hatchet_sdk.runnables.contextvars import (
     ctx_worker_id,
     ctx_workflow_run_id,
     spawn_index_lock,
+    task_count,
     workflow_spawn_indices,
 )
 from hatchet_sdk.runnables.task import Task
 from hatchet_sdk.runnables.types import R, TWorkflowInput
 from hatchet_sdk.worker.action_listener_process import ActionEvent
-from hatchet_sdk.worker.runner.utils.capture_logs import copy_context_vars
+from hatchet_sdk.worker.runner.utils.capture_logs import (
+    AsyncLogSender,
+    ContextVarToCopy,
+    copy_context_vars,
+)
 
 
 class WorkerStatus(Enum):
@@ -61,10 +66,11 @@
         event_queue: "Queue[ActionEvent]",
         config: ClientConfig,
         slots: int,
-        handle_kill: bool = True,
-        action_registry: dict[str, Task[TWorkflowInput, R]] = {},
-        labels: dict[str, str | int] = {},
-        lifespan_context: Any | None = None,
+        handle_kill: bool,
+        action_registry: dict[str, Task[TWorkflowInput, R]],
+        labels: dict[str, str | int] | None,
+        lifespan_context: Any | None,
+        log_sender: AsyncLogSender,
     ):
         # We store the config so we can dynamically create clients for the dispatcher client.
         self.config = config
@@ -72,13 +78,14 @@
         self.slots = slots
         self.tasks: dict[ActionKey, asyncio.Task[Any]] = {}  # Store run ids and futures
         self.contexts: dict[ActionKey, Context] = {}  # Store run ids and contexts
-        self.action_registry = action_registry
+        self.action_registry = action_registry or {}
 
         self.event_queue = event_queue
 
         # The thread pool is used for synchronous functions which need to run concurrently
         self.thread_pool = ThreadPoolExecutor(max_workers=slots)
-        self.threads: Dict[ActionKey, Thread] = {}  # Store run ids and threads
+        self.threads: dict[ActionKey, Thread] = {}  # Store run ids and threads
+        self.running_tasks = set[asyncio.Task[Exception | None]]()
 
         self.killing = False
         self.handle_kill = handle_kill
@@ -101,10 +108,11 @@
         self.durable_event_listener = DurableEventListener(self.config)
 
         self.worker_context = WorkerContext(
-            labels=labels, client=Client(config=config).dispatcher
+            labels=labels or {}, client=Client(config=config).dispatcher
         )
 
         self.lifespan_context = lifespan_context
+        self.log_sender = log_sender
 
         if self.config.enable_thread_pool_monitoring:
             self.start_background_monitoring()
@@ -116,67 +124,68 @@
         if self.worker_context.id() is None:
             self.worker_context._worker_id = action.worker_id
 
+        t: asyncio.Task[Exception | None] | None = None
         match action.action_type:
             case ActionType.START_STEP_RUN:
                 log = f"run: start step: {action.action_id}/{action.step_run_id}"
                 logger.info(log)
-                asyncio.create_task(self.handle_start_step_run(action))
+                t = asyncio.create_task(self.handle_start_step_run(action))
             case ActionType.CANCEL_STEP_RUN:
                 log = f"cancel: step run: {action.action_id}/{action.step_run_id}/{action.retry_count}"
                 logger.info(log)
-                asyncio.create_task(self.handle_cancel_action(action))
+                t = asyncio.create_task(self.handle_cancel_action(action))
             case ActionType.START_GET_GROUP_KEY:
                 log = f"run: get group key: {action.action_id}/{action.get_group_key_run_id}"
                 logger.info(log)
-                asyncio.create_task(self.handle_start_group_key_run(action))
+                t = asyncio.create_task(self.handle_start_group_key_run(action))
             case _:
                 log = f"unknown action type: {action.action_type}"
                 logger.error(log)
 
+        if t is not None:
+            self.running_tasks.add(t)
+            t.add_done_callback(lambda task: self.running_tasks.discard(task))
+
     def step_run_callback(self, action: Action) -> Callable[[asyncio.Task[Any]], None]:
         def inner_callback(task: asyncio.Task[Any]) -> None:
             self.cleanup_run_id(action.key)
 
-            errored = False
-            cancelled = task.cancelled()
-            output = None
+            if task.cancelled():
+                return
 
-            # Get the output from the future
             try:
-                if not cancelled:
-                    output = task.result()
+                output = task.result()
             except Exception as e:
-                errored = True
-
                 should_not_retry = isinstance(e, NonRetryableException)
 
+                exc = TaskRunError.from_exception(e)
+
                 # This except is coming from the application itself, so we want to send that to the Hatchet instance
                 self.event_queue.put(
                     ActionEvent(
                         action=action,
                         type=STEP_EVENT_TYPE_FAILED,
-                        payload=str(pretty_format_exception(f"{e}", e)),
+                        payload=exc.serialize(),
                        should_not_retry=should_not_retry,
                     )
                 )
 
                 logger.error(
-                    f"failed step run: {action.action_id}/{action.step_run_id}"
+                    f"failed step run: {action.action_id}/{action.step_run_id}\n{exc.serialize()}"
                 )
 
-            if not errored and not cancelled:
-                self.event_queue.put(
-                    ActionEvent(
-                        action=action,
-                        type=STEP_EVENT_TYPE_COMPLETED,
-                        payload=self.serialize_output(output),
-                        should_not_retry=False,
-                    )
-                )
+                return
 
-                logger.info(
-                    f"finished step run: {action.action_id}/{action.step_run_id}"
+            self.event_queue.put(
+                ActionEvent(
+                    action=action,
+                    type=STEP_EVENT_TYPE_COMPLETED,
+                    payload=self.serialize_output(output),
+                    should_not_retry=False,
                 )
+            )
+
+            logger.info(f"finished step run: {action.action_id}/{action.step_run_id}")
 
         return inner_callback
 
@@ -186,51 +195,46 @@
         def inner_callback(task: asyncio.Task[Any]) -> None:
             self.cleanup_run_id(action.key)
 
-            errored = False
-            cancelled = task.cancelled()
-            output = None
+            if task.cancelled():
+                return
 
-            # Get the output from the future
             try:
-                if not cancelled:
-                    output = task.result()
+                output = task.result()
             except Exception as e:
-                errored = True
+                exc = TaskRunError.from_exception(e)
+
                 self.event_queue.put(
                     ActionEvent(
                         action=action,
                         type=GROUP_KEY_EVENT_TYPE_FAILED,
-                        payload=str(pretty_format_exception(f"{e}", e)),
+                        payload=exc.serialize(),
                        should_not_retry=False,
                     )
                 )
 
                 logger.error(
-                    f"failed step run: {action.action_id}/{action.step_run_id}"
+                    f"failed step run: {action.action_id}/{action.step_run_id}\n{exc.serialize()}"
                 )
 
-            if not errored and not cancelled:
-                self.event_queue.put(
-                    ActionEvent(
-                        action=action,
-                        type=GROUP_KEY_EVENT_TYPE_COMPLETED,
-                        payload=self.serialize_output(output),
-                        should_not_retry=False,
-                    )
-                )
+                return
 
-                logger.info(
-                    f"finished step run: {action.action_id}/{action.step_run_id}"
+            self.event_queue.put(
+                ActionEvent(
+                    action=action,
+                    type=GROUP_KEY_EVENT_TYPE_COMPLETED,
+                    payload=self.serialize_output(output),
+                    should_not_retry=False,
                 )
+            )
+
+            logger.info(f"finished step run: {action.action_id}/{action.step_run_id}")
 
         return inner_callback
 
     def thread_action_func(
         self, ctx: Context, task: Task[TWorkflowInput, R], action: Action
     ) -> R:
-        if action.step_run_id:
-            self.threads[action.key] = current_thread()
-        elif action.get_group_key_run_id:
+        if action.step_run_id or action.get_group_key_run_id:
             self.threads[action.key] = current_thread()
 
         return task.call(ctx)
@@ -250,28 +254,36 @@
         try:
             if task.is_async_function:
                 return await task.aio_call(ctx)
-            else:
-                pfunc = functools.partial(
-                    # we must copy the context vars to the new thread, as only asyncio natively supports
-                    # contextvars
-                    copy_context_vars,
-                    contextvars.copy_context().items(),
-                    self.thread_action_func,
-                    ctx,
-                    task,
-                    action,
-                )
-
-                loop = asyncio.get_event_loop()
-                return await loop.run_in_executor(self.thread_pool, pfunc)
-        except Exception as e:
-            logger.error(
-                pretty_format_exception(
-                    f"exception raised in action ({action.action_id}, retry={action.retry_count}):\n{e}",
-                    e,
-                )
+            pfunc = functools.partial(
+                # we must copy the context vars to the new thread, as only asyncio natively supports
+                # contextvars
+                copy_context_vars,
+                [
+                    ContextVarToCopy(
+                        name="ctx_step_run_id",
+                        value=action.step_run_id,
+                    ),
+                    ContextVarToCopy(
+                        name="ctx_workflow_run_id",
+                        value=action.workflow_run_id,
+                    ),
+                    ContextVarToCopy(
+                        name="ctx_worker_id",
+                        value=action.worker_id,
+                    ),
+                    ContextVarToCopy(
+                        name="ctx_action_key",
+                        value=action.key,
+                    ),
+                ],
+                self.thread_action_func,
+                ctx,
+                task,
+                action,
             )
-            raise e
+
+            loop = asyncio.get_event_loop()
+            return await loop.run_in_executor(self.thread_pool, pfunc)
         finally:
             self.cleanup_run_id(action.key)
@@ -295,7 +307,7 @@
         while True:
             await self.log_thread_pool_status()
 
-            for key in self.threads.keys():
+            for key in self.threads:
                 if key not in self.tasks:
                     logger.debug(f"Potential zombie thread found for key {key}")
 
@@ -350,6 +362,7 @@
             worker=self.worker_context,
             runs_client=self.runs_client,
             lifespan_context=self.lifespan_context,
+            log_sender=self.log_sender,
         )
 
     ## IMPORTANT: Keep this method's signature in sync with the wrapper in the OTel instrumentor
@@ -361,7 +374,8 @@
 
         if action_func:
             context = self.create_context(
-                action, True if action_func.is_durable else False
+                action,
+                True if action_func.is_durable else False,  # noqa: SIM210
             )
 
             self.contexts[action.key] = context
@@ -382,11 +396,12 @@
             task.add_done_callback(self.step_run_callback(action))
             self.tasks[action.key] = task
 
-            try:
+            task_count.increment()
+
+            ## FIXME: Handle cancelled exceptions and other special exceptions
+            ## that we don't want to suppress here
+            with suppress(Exception):
                 await task
-            except Exception:
-                # do nothing, this should be caught in the callback
-                pass
 
             ## Once the step run completes, we need to remove the workflow spawn index
             ## so we don't leak memory
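contextlib.suppress(Exception) is the idiomatic equivalent of the try/except/pass it replaces here (results and errors are handled by the done-callback attached above, and the remaining caveat is carried in the diff's own FIXME). For reference:

```python
from contextlib import suppress

# Equivalent to: try: int("not a number")  except ValueError: pass
with suppress(ValueError):
    int("not a number")
print("execution continues")
```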
@@ -444,7 +459,7 @@
         res = ctypes.pythonapi.PyThreadState_SetAsyncExc(ctypes.c_long(ident), exc)
         if res == 0:
             raise ValueError("Invalid thread ID")
-        elif res != 1:
+        if res != 1:
             logger.error("PyThreadState_SetAsyncExc failed")
 
         # Call with exception set to 0 is needed to cleanup properly.
@@ -505,8 +520,3 @@
             logger.info(f"waiting for {running} tasks to finish...")
             await asyncio.sleep(1)
             running = len(self.tasks.keys())
-
-
-def pretty_format_exception(message: str, e: Exception) -> str:
-    trace = "".join(traceback.format_exception(type(e), e, e.__traceback__))
-    return f"{message}\n{trace}"
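pretty_format_exception is deleted because failures are now wrapped in TaskRunError (part of the +99 lines added to hatchet_sdk/exceptions.py in this release) and shipped to the server as a serialized payload. The real class is not visible in this diff; a hypothetical reconstruction of the general shape such a wrapper takes, for orientation only:

```python
import traceback

class TaskRunError(Exception):
    # Hypothetical sketch; the actual fields and wire format live in
    # hatchet_sdk/exceptions.py and are not shown in this diff.
    def __init__(self, exc_type: str, message: str, stack_trace: str) -> None:
        super().__init__(message)
        self.exc_type = exc_type
        self.message = message
        self.stack_trace = stack_trace

    @classmethod
    def from_exception(cls, e: Exception) -> "TaskRunError":
        return cls(
            exc_type=type(e).__name__,
            message=str(e),
            stack_trace="".join(traceback.format_exception(e)),
        )

    def serialize(self) -> str:
        return f"{self.exc_type}: {self.message}\n{self.stack_trace}"
```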