hatchet-sdk 1.12.3__py3-none-any.whl → 1.14.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of hatchet-sdk might be problematic.

Files changed (83)
  1. hatchet_sdk/__init__.py +54 -40
  2. hatchet_sdk/clients/admin.py +18 -23
  3. hatchet_sdk/clients/dispatcher/action_listener.py +4 -3
  4. hatchet_sdk/clients/dispatcher/dispatcher.py +1 -4
  5. hatchet_sdk/clients/event_ts.py +2 -1
  6. hatchet_sdk/clients/events.py +16 -12
  7. hatchet_sdk/clients/listeners/durable_event_listener.py +4 -2
  8. hatchet_sdk/clients/listeners/pooled_listener.py +2 -2
  9. hatchet_sdk/clients/listeners/run_event_listener.py +7 -8
  10. hatchet_sdk/clients/listeners/workflow_listener.py +14 -6
  11. hatchet_sdk/clients/rest/api_response.py +3 -2
  12. hatchet_sdk/clients/rest/models/semaphore_slots.py +1 -1
  13. hatchet_sdk/clients/rest/models/v1_task_summary.py +5 -0
  14. hatchet_sdk/clients/rest/models/v1_workflow_run_details.py +11 -1
  15. hatchet_sdk/clients/rest/models/workflow_version.py +5 -0
  16. hatchet_sdk/clients/rest/tenacity_utils.py +6 -8
  17. hatchet_sdk/config.py +2 -0
  18. hatchet_sdk/connection.py +10 -4
  19. hatchet_sdk/context/context.py +170 -46
  20. hatchet_sdk/context/worker_context.py +4 -7
  21. hatchet_sdk/contracts/dispatcher_pb2.py +38 -38
  22. hatchet_sdk/contracts/dispatcher_pb2.pyi +4 -2
  23. hatchet_sdk/contracts/events_pb2.py +13 -13
  24. hatchet_sdk/contracts/events_pb2.pyi +4 -2
  25. hatchet_sdk/contracts/v1/workflows_pb2.py +1 -1
  26. hatchet_sdk/contracts/v1/workflows_pb2.pyi +2 -2
  27. hatchet_sdk/exceptions.py +103 -1
  28. hatchet_sdk/features/cron.py +2 -2
  29. hatchet_sdk/features/filters.py +12 -21
  30. hatchet_sdk/features/runs.py +4 -4
  31. hatchet_sdk/features/scheduled.py +8 -9
  32. hatchet_sdk/hatchet.py +65 -64
  33. hatchet_sdk/opentelemetry/instrumentor.py +20 -20
  34. hatchet_sdk/runnables/action.py +1 -2
  35. hatchet_sdk/runnables/contextvars.py +19 -0
  36. hatchet_sdk/runnables/task.py +37 -29
  37. hatchet_sdk/runnables/types.py +9 -8
  38. hatchet_sdk/runnables/workflow.py +57 -42
  39. hatchet_sdk/utils/proto_enums.py +4 -4
  40. hatchet_sdk/utils/timedelta_to_expression.py +2 -3
  41. hatchet_sdk/utils/typing.py +11 -17
  42. hatchet_sdk/v0/__init__.py +7 -7
  43. hatchet_sdk/v0/clients/admin.py +7 -7
  44. hatchet_sdk/v0/clients/dispatcher/action_listener.py +8 -8
  45. hatchet_sdk/v0/clients/dispatcher/dispatcher.py +9 -9
  46. hatchet_sdk/v0/clients/events.py +3 -3
  47. hatchet_sdk/v0/clients/rest/tenacity_utils.py +1 -1
  48. hatchet_sdk/v0/clients/run_event_listener.py +3 -3
  49. hatchet_sdk/v0/clients/workflow_listener.py +5 -5
  50. hatchet_sdk/v0/context/context.py +6 -6
  51. hatchet_sdk/v0/hatchet.py +4 -4
  52. hatchet_sdk/v0/opentelemetry/instrumentor.py +1 -1
  53. hatchet_sdk/v0/rate_limit.py +1 -1
  54. hatchet_sdk/v0/v2/callable.py +4 -4
  55. hatchet_sdk/v0/v2/concurrency.py +2 -2
  56. hatchet_sdk/v0/v2/hatchet.py +3 -3
  57. hatchet_sdk/v0/worker/action_listener_process.py +6 -6
  58. hatchet_sdk/v0/worker/runner/run_loop_manager.py +1 -1
  59. hatchet_sdk/v0/worker/runner/runner.py +10 -10
  60. hatchet_sdk/v0/worker/runner/utils/capture_logs.py +1 -1
  61. hatchet_sdk/v0/worker/worker.py +2 -2
  62. hatchet_sdk/v0/workflow.py +3 -3
  63. hatchet_sdk/waits.py +6 -5
  64. hatchet_sdk/worker/action_listener_process.py +33 -13
  65. hatchet_sdk/worker/runner/run_loop_manager.py +15 -11
  66. hatchet_sdk/worker/runner/runner.py +142 -80
  67. hatchet_sdk/worker/runner/utils/capture_logs.py +72 -31
  68. hatchet_sdk/worker/worker.py +30 -26
  69. hatchet_sdk/workflow_run.py +4 -2
  70. {hatchet_sdk-1.12.3.dist-info → hatchet_sdk-1.14.0.dist-info}/METADATA +1 -1
  71. {hatchet_sdk-1.12.3.dist-info → hatchet_sdk-1.14.0.dist-info}/RECORD +73 -83
  72. hatchet_sdk/v0/contracts/dispatcher_pb2.py +0 -102
  73. hatchet_sdk/v0/contracts/dispatcher_pb2.pyi +0 -387
  74. hatchet_sdk/v0/contracts/dispatcher_pb2_grpc.py +0 -621
  75. hatchet_sdk/v0/contracts/events_pb2.py +0 -46
  76. hatchet_sdk/v0/contracts/events_pb2.pyi +0 -87
  77. hatchet_sdk/v0/contracts/events_pb2_grpc.py +0 -274
  78. hatchet_sdk/v0/contracts/workflows_pb2.py +0 -80
  79. hatchet_sdk/v0/contracts/workflows_pb2.pyi +0 -312
  80. hatchet_sdk/v0/contracts/workflows_pb2_grpc.py +0 -277
  81. hatchet_sdk/v0/logger.py +0 -13
  82. {hatchet_sdk-1.12.3.dist-info → hatchet_sdk-1.14.0.dist-info}/WHEEL +0 -0
  83. {hatchet_sdk-1.12.3.dist-info → hatchet_sdk-1.14.0.dist-info}/entry_points.txt +0 -0
hatchet_sdk/worker/runner/runner.py
@@ -1,14 +1,14 @@
 import asyncio
-import contextvars
 import ctypes
 import functools
 import json
-import traceback
+from collections.abc import Callable
 from concurrent.futures import ThreadPoolExecutor
+from contextlib import suppress
 from enum import Enum
 from multiprocessing import Queue
 from threading import Thread, current_thread
-from typing import Any, Callable, Dict, Literal, cast, overload
+from typing import Any, Literal, cast, overload
 
 from pydantic import BaseModel
 
@@ -30,7 +30,11 @@ from hatchet_sdk.contracts.dispatcher_pb2 import (
     STEP_EVENT_TYPE_FAILED,
     STEP_EVENT_TYPE_STARTED,
 )
-from hatchet_sdk.exceptions import NonRetryableException
+from hatchet_sdk.exceptions import (
+    IllegalTaskOutputError,
+    NonRetryableException,
+    TaskRunError,
+)
 from hatchet_sdk.features.runs import RunsClient
 from hatchet_sdk.logger import logger
 from hatchet_sdk.runnables.action import Action, ActionKey, ActionType
@@ -40,12 +44,17 @@ from hatchet_sdk.runnables.contextvars import (
     ctx_worker_id,
     ctx_workflow_run_id,
     spawn_index_lock,
+    task_count,
     workflow_spawn_indices,
 )
 from hatchet_sdk.runnables.task import Task
 from hatchet_sdk.runnables.types import R, TWorkflowInput
 from hatchet_sdk.worker.action_listener_process import ActionEvent
-from hatchet_sdk.worker.runner.utils.capture_logs import copy_context_vars
+from hatchet_sdk.worker.runner.utils.capture_logs import (
+    AsyncLogSender,
+    ContextVarToCopy,
+    copy_context_vars,
+)
 
 
 class WorkerStatus(Enum):
@@ -61,10 +70,11 @@ class Runner:
         event_queue: "Queue[ActionEvent]",
         config: ClientConfig,
         slots: int,
-        handle_kill: bool = True,
-        action_registry: dict[str, Task[TWorkflowInput, R]] = {},
-        labels: dict[str, str | int] = {},
-        lifespan_context: Any | None = None,
+        handle_kill: bool,
+        action_registry: dict[str, Task[TWorkflowInput, R]],
+        labels: dict[str, str | int] | None,
+        lifespan_context: Any | None,
+        log_sender: AsyncLogSender,
     ):
         # We store the config so we can dynamically create clients for the dispatcher client.
         self.config = config
@@ -72,13 +82,14 @@ class Runner:
         self.slots = slots
         self.tasks: dict[ActionKey, asyncio.Task[Any]] = {}  # Store run ids and futures
         self.contexts: dict[ActionKey, Context] = {}  # Store run ids and contexts
-        self.action_registry = action_registry
+        self.action_registry = action_registry or {}
 
         self.event_queue = event_queue
 
         # The thread pool is used for synchronous functions which need to run concurrently
         self.thread_pool = ThreadPoolExecutor(max_workers=slots)
-        self.threads: Dict[ActionKey, Thread] = {}  # Store run ids and threads
+        self.threads: dict[ActionKey, Thread] = {}  # Store run ids and threads
+        self.running_tasks = set[asyncio.Task[Exception | None]]()
 
         self.killing = False
         self.handle_kill = handle_kill
@@ -101,10 +112,11 @@ class Runner:
         self.durable_event_listener = DurableEventListener(self.config)
 
         self.worker_context = WorkerContext(
-            labels=labels, client=Client(config=config).dispatcher
+            labels=labels or {}, client=Client(config=config).dispatcher
         )
 
         self.lifespan_context = lifespan_context
+        self.log_sender = log_sender
 
         if self.config.enable_thread_pool_monitoring:
             self.start_background_monitoring()
@@ -116,68 +128,90 @@
         if self.worker_context.id() is None:
             self.worker_context._worker_id = action.worker_id
 
+        t: asyncio.Task[Exception | None] | None = None
         match action.action_type:
             case ActionType.START_STEP_RUN:
                 log = f"run: start step: {action.action_id}/{action.step_run_id}"
                 logger.info(log)
-                asyncio.create_task(self.handle_start_step_run(action))
+                t = asyncio.create_task(self.handle_start_step_run(action))
             case ActionType.CANCEL_STEP_RUN:
                 log = f"cancel: step run: {action.action_id}/{action.step_run_id}/{action.retry_count}"
                 logger.info(log)
-                asyncio.create_task(self.handle_cancel_action(action))
+                t = asyncio.create_task(self.handle_cancel_action(action))
             case ActionType.START_GET_GROUP_KEY:
                 log = f"run: get group key: {action.action_id}/{action.get_group_key_run_id}"
                 logger.info(log)
-                asyncio.create_task(self.handle_start_group_key_run(action))
+                t = asyncio.create_task(self.handle_start_group_key_run(action))
             case _:
                 log = f"unknown action type: {action.action_type}"
                 logger.error(log)
 
+        if t is not None:
+            self.running_tasks.add(t)
+            t.add_done_callback(lambda task: self.running_tasks.discard(task))
+
     def step_run_callback(self, action: Action) -> Callable[[asyncio.Task[Any]], None]:
         def inner_callback(task: asyncio.Task[Any]) -> None:
             self.cleanup_run_id(action.key)
 
-            errored = False
-            cancelled = task.cancelled()
-            output = None
+            if task.cancelled():
+                return
 
-            # Get the output from the future
             try:
-                if not cancelled:
-                    output = task.result()
+                output = task.result()
             except Exception as e:
-                errored = True
-
                 should_not_retry = isinstance(e, NonRetryableException)
 
+                exc = TaskRunError.from_exception(e)
+
                 # This except is coming from the application itself, so we want to send that to the Hatchet instance
                 self.event_queue.put(
                     ActionEvent(
                         action=action,
                         type=STEP_EVENT_TYPE_FAILED,
-                        payload=str(pretty_format_exception(f"{e}", e)),
+                        payload=exc.serialize(),
                         should_not_retry=should_not_retry,
                     )
                 )
 
-                logger.error(
-                    f"failed step run: {action.action_id}/{action.step_run_id}"
+                log_with_level = logger.info if should_not_retry else logger.error
+
+                log_with_level(
+                    f"failed step run: {action.action_id}/{action.step_run_id}\n{exc.serialize()}"
                 )
 
-            if not errored and not cancelled:
+                return
+
+            try:
+                output = self.serialize_output(output)
+
                 self.event_queue.put(
                     ActionEvent(
                         action=action,
                         type=STEP_EVENT_TYPE_COMPLETED,
-                        payload=self.serialize_output(output),
+                        payload=output,
+                        should_not_retry=False,
+                    )
+                )
+            except IllegalTaskOutputError as e:
+                exc = TaskRunError.from_exception(e)
+                self.event_queue.put(
+                    ActionEvent(
+                        action=action,
+                        type=STEP_EVENT_TYPE_FAILED,
+                        payload=exc.serialize(),
                         should_not_retry=False,
                     )
                 )
 
-                logger.info(
-                    f"finished step run: {action.action_id}/{action.step_run_id}"
+                logger.error(
+                    f"failed step run: {action.action_id}/{action.step_run_id}\n{exc.serialize()}"
                 )
 
+                return
+
+            logger.info(f"finished step run: {action.action_id}/{action.step_run_id}")
+
         return inner_callback
 
     def group_key_run_callback(
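
The `t = asyncio.create_task(...)` changes and the new `running_tasks` set above apply the standard asyncio pattern for fire-and-forget tasks: the event loop keeps only a weak reference to a task, so without a strong reference it can be garbage-collected before it finishes. A minimal standalone sketch of the same pattern (the names here are illustrative, not part of the SDK):

    import asyncio

    running_tasks: set[asyncio.Task[None]] = set()

    async def handle(msg: str) -> None:
        await asyncio.sleep(0.1)
        print(f"handled {msg}")

    def dispatch(msg: str) -> None:
        # Hold a strong reference until the task completes, then discard it
        # in the done callback so the set does not grow without bound.
        t = asyncio.create_task(handle(msg))
        running_tasks.add(t)
        t.add_done_callback(running_tasks.discard)
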
@@ -186,51 +220,65 @@
         def inner_callback(task: asyncio.Task[Any]) -> None:
             self.cleanup_run_id(action.key)
 
-            errored = False
-            cancelled = task.cancelled()
-            output = None
+            if task.cancelled():
+                return
 
-            # Get the output from the future
             try:
-                if not cancelled:
-                    output = task.result()
+                output = task.result()
             except Exception as e:
-                errored = True
+                exc = TaskRunError.from_exception(e)
+
                 self.event_queue.put(
                     ActionEvent(
                         action=action,
                         type=GROUP_KEY_EVENT_TYPE_FAILED,
-                        payload=str(pretty_format_exception(f"{e}", e)),
+                        payload=exc.serialize(),
                         should_not_retry=False,
                     )
                 )
 
                 logger.error(
-                    f"failed step run: {action.action_id}/{action.step_run_id}"
+                    f"failed step run: {action.action_id}/{action.step_run_id}\n{exc.serialize()}"
                 )
 
-            if not errored and not cancelled:
+                return
+
+            try:
+                output = self.serialize_output(output)
+
                 self.event_queue.put(
                     ActionEvent(
                         action=action,
                         type=GROUP_KEY_EVENT_TYPE_COMPLETED,
-                        payload=self.serialize_output(output),
+                        payload=output,
+                        should_not_retry=False,
+                    )
+                )
+            except IllegalTaskOutputError as e:
+                exc = TaskRunError.from_exception(e)
+                self.event_queue.put(
+                    ActionEvent(
+                        action=action,
+                        type=STEP_EVENT_TYPE_FAILED,
+                        payload=exc.serialize(),
                         should_not_retry=False,
                     )
                 )
 
-                logger.info(
-                    f"finished step run: {action.action_id}/{action.step_run_id}"
+                logger.error(
+                    f"failed step run: {action.action_id}/{action.step_run_id}\n{exc.serialize()}"
                 )
 
+                return
+
+            logger.info(f"finished step run: {action.action_id}/{action.step_run_id}")
+
         return inner_callback
 
     def thread_action_func(
         self, ctx: Context, task: Task[TWorkflowInput, R], action: Action
     ) -> R:
-        if action.step_run_id:
-            self.threads[action.key] = current_thread()
-        elif action.get_group_key_run_id:
+        if action.step_run_id or action.get_group_key_run_id:
             self.threads[action.key] = current_thread()
 
         return task.call(ctx)
@@ -250,28 +298,36 @@
         try:
             if task.is_async_function:
                 return await task.aio_call(ctx)
-            else:
-                pfunc = functools.partial(
-                    # we must copy the context vars to the new thread, as only asyncio natively supports
-                    # contextvars
-                    copy_context_vars,
-                    contextvars.copy_context().items(),
-                    self.thread_action_func,
-                    ctx,
-                    task,
-                    action,
-                )
-
-                loop = asyncio.get_event_loop()
-                return await loop.run_in_executor(self.thread_pool, pfunc)
-        except Exception as e:
-            logger.error(
-                pretty_format_exception(
-                    f"exception raised in action ({action.action_id}, retry={action.retry_count}):\n{e}",
-                    e,
-                )
+            pfunc = functools.partial(
+                # we must copy the context vars to the new thread, as only asyncio natively supports
+                # contextvars
+                copy_context_vars,
+                [
+                    ContextVarToCopy(
+                        name="ctx_step_run_id",
+                        value=action.step_run_id,
+                    ),
+                    ContextVarToCopy(
+                        name="ctx_workflow_run_id",
+                        value=action.workflow_run_id,
+                    ),
+                    ContextVarToCopy(
+                        name="ctx_worker_id",
+                        value=action.worker_id,
+                    ),
+                    ContextVarToCopy(
+                        name="ctx_action_key",
+                        value=action.key,
+                    ),
+                ],
+                self.thread_action_func,
+                ctx,
+                task,
+                action,
             )
-            raise e
+
+            loop = asyncio.get_event_loop()
+            return await loop.run_in_executor(self.thread_pool, pfunc)
         finally:
             self.cleanup_run_id(action.key)
 
@@ -295,7 +351,7 @@
         while True:
             await self.log_thread_pool_status()
 
-            for key in self.threads.keys():
+            for key in self.threads:
                 if key not in self.tasks:
                     logger.debug(f"Potential zombie thread found for key {key}")
 
@@ -350,6 +406,7 @@
             worker=self.worker_context,
             runs_client=self.runs_client,
             lifespan_context=self.lifespan_context,
+            log_sender=self.log_sender,
         )
 
     ## IMPORTANT: Keep this method's signature in sync with the wrapper in the OTel instrumentor
@@ -361,7 +418,8 @@
 
         if action_func:
             context = self.create_context(
-                action, True if action_func.is_durable else False
+                action,
+                True if action_func.is_durable else False,  # noqa: SIM210
             )
 
             self.contexts[action.key] = context
@@ -382,11 +440,12 @@
             task.add_done_callback(self.step_run_callback(action))
             self.tasks[action.key] = task
 
-            try:
+            task_count.increment()
+
+            ## FIXME: Handle cancelled exceptions and other special exceptions
+            ## that we don't want to suppress here
+            with suppress(Exception):
                 await task
-            except Exception:
-                # do nothing, this should be caught in the callback
-                pass
 
             ## Once the step run completes, we need to remove the workflow spawn index
             ## so we don't leak memory
@@ -444,7 +503,7 @@
         res = ctypes.pythonapi.PyThreadState_SetAsyncExc(ctypes.c_long(ident), exc)
         if res == 0:
             raise ValueError("Invalid thread ID")
-        elif res != 1:
+        if res != 1:
             logger.error("PyThreadState_SetAsyncExc failed")
 
         # Call with exception set to 0 is needed to cleanup properly.
@@ -487,8 +546,16 @@
             self.cleanup_run_id(key)
 
     def serialize_output(self, output: Any) -> str:
+        if not output:
+            return ""
+
         if isinstance(output, BaseModel):
-            return output.model_dump_json()
+            output = output.model_dump()
+
+        if not isinstance(output, dict):
+            raise IllegalTaskOutputError(
+                f"Tasks must return either a dictionary or a Pydantic BaseModel which can be serialized to a JSON object. Got object of type {type(output)} instead."
+            )
 
         if output is not None:
             try:
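
The reworked `serialize_output` above tightens the task output contract: falsy outputs serialize to an empty string, a Pydantic `BaseModel` is normalized via `model_dump()`, and anything that is not a `dict` at that point raises `IllegalTaskOutputError`, which the callbacks above report as a failed, non-retried step. A distilled sketch of the same check; the trailing `json.dumps` is an assumption, since the hunk cuts off before the actual serialization:

    import json
    from typing import Any

    from pydantic import BaseModel

    class IllegalTaskOutputError(Exception):  # stand-in for hatchet_sdk.exceptions
        pass

    def serialize_output(output: Any) -> str:
        if not output:
            return ""

        if isinstance(output, BaseModel):
            output = output.model_dump()  # normalize models to plain dicts

        if not isinstance(output, dict):
            raise IllegalTaskOutputError(
                f"Tasks must return a dict or a Pydantic BaseModel, got {type(output)}"
            )

        return json.dumps(output, default=str)

    serialize_output({"status": "ok"})  # '{"status": "ok"}'
    serialize_output(42)                # raises IllegalTaskOutputError
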
@@ -505,8 +572,3 @@
             logger.info(f"waiting for {running} tasks to finish...")
             await asyncio.sleep(1)
             running = len(self.tasks.keys())
-
-
-def pretty_format_exception(message: str, e: Exception) -> str:
-    trace = "".join(traceback.format_exception(type(e), e, e.__traceback__))
-    return f"{message}\n{trace}"
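
Note the shift in the `copy_context_vars` call earlier in this file: instead of copying the caller's entire `contextvars` context into the worker thread, the runner now passes an explicit allow-list of `ContextVarToCopy` entries (defined in `capture_logs.py`, diffed below). The reason is that `loop.run_in_executor` runs the function in a plain thread, which does not inherit the coroutine's context, so each variable must be re-set inside the thread. A reduced sketch of the mechanism (the scaffolding here is illustrative, not SDK code):

    import asyncio
    from concurrent.futures import ThreadPoolExecutor
    from contextvars import ContextVar

    ctx_step_run_id: ContextVar[str | None] = ContextVar("ctx_step_run_id", default=None)

    def run_with_ctx(step_run_id: str | None) -> str | None:
        # Plain threads do not see values set in the asyncio context,
        # so the value is re-set explicitly inside the worker thread.
        ctx_step_run_id.set(step_run_id)
        return ctx_step_run_id.get()

    async def main() -> None:
        ctx_step_run_id.set("step-123")
        loop = asyncio.get_event_loop()
        with ThreadPoolExecutor() as pool:
            out = await loop.run_in_executor(pool, run_with_ctx, ctx_step_run_id.get())
        print(out)  # step-123

    asyncio.run(main())
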
hatchet_sdk/worker/runner/utils/capture_logs.py
@@ -1,73 +1,114 @@
+import asyncio
 import functools
 import logging
-from concurrent.futures import ThreadPoolExecutor
-from contextvars import ContextVar
+from collections.abc import Awaitable, Callable
 from io import StringIO
-from typing import Any, Awaitable, Callable, ItemsView, ParamSpec, TypeVar
+from typing import Literal, ParamSpec, TypeVar
+
+from pydantic import BaseModel
 
 from hatchet_sdk.clients.events import EventClient
 from hatchet_sdk.logger import logger
-from hatchet_sdk.runnables.contextvars import ctx_step_run_id, ctx_workflow_run_id
+from hatchet_sdk.runnables.contextvars import (
+    ctx_action_key,
+    ctx_step_run_id,
+    ctx_worker_id,
+    ctx_workflow_run_id,
+)
+from hatchet_sdk.utils.typing import STOP_LOOP, STOP_LOOP_TYPE
 
 T = TypeVar("T")
 P = ParamSpec("P")
 
 
+class ContextVarToCopy(BaseModel):
+    name: Literal[
+        "ctx_workflow_run_id", "ctx_step_run_id", "ctx_action_key", "ctx_worker_id"
+    ]
+    value: str | None
+
+
 def copy_context_vars(
-    ctx_vars: ItemsView[ContextVar[Any], Any],
+    ctx_vars: list[ContextVarToCopy],
     func: Callable[P, T],
     *args: P.args,
     **kwargs: P.kwargs,
 ) -> T:
-    for var, value in ctx_vars:
-        var.set(value)
+    for var in ctx_vars:
+        if var.name == "ctx_workflow_run_id":
+            ctx_workflow_run_id.set(var.value)
+        elif var.name == "ctx_step_run_id":
+            ctx_step_run_id.set(var.value)
+        elif var.name == "ctx_action_key":
+            ctx_action_key.set(var.value)
+        elif var.name == "ctx_worker_id":
+            ctx_worker_id.set(var.value)
+        else:
+            raise ValueError(f"Unknown context variable name: {var.name}")
+
     return func(*args, **kwargs)
 
 
-class InjectingFilter(logging.Filter):
-    # For some reason, only the InjectingFilter has access to the contextvars method sr.get(),
-    # otherwise we would use emit within the CustomLogHandler
-    def filter(self, record: logging.LogRecord) -> bool:
-        ## TODO: Change how we do this to not assign to the log record
-        record.workflow_run_id = ctx_workflow_run_id.get()
-        record.step_run_id = ctx_step_run_id.get()
-        return True
+class LogRecord(BaseModel):
+    message: str
+    step_run_id: str
 
 
-class CustomLogHandler(logging.StreamHandler):  # type: ignore[type-arg]
-    def __init__(self, event_client: EventClient, stream: StringIO | None = None):
-        super().__init__(stream)
-        self.logger_thread_pool = ThreadPoolExecutor(max_workers=1)
+class AsyncLogSender:
+    def __init__(self, event_client: EventClient):
         self.event_client = event_client
+        self.q = asyncio.Queue[LogRecord | STOP_LOOP_TYPE](maxsize=1000)
+
+    async def consume(self) -> None:
+        while True:
+            record = await self.q.get()
+
+            if record == STOP_LOOP:
+                break
 
-    def _log(self, line: str, step_run_id: str | None) -> None:
+            try:
+                self.event_client.log(
+                    message=record.message, step_run_id=record.step_run_id
+                )
+            except Exception as e:
+                logger.error(f"Error logging: {e}")
+
+    def publish(self, record: LogRecord | STOP_LOOP_TYPE) -> None:
         try:
-            if not step_run_id:
-                return
+            self.q.put_nowait(record)
+        except asyncio.QueueFull:
+            logger.warning("Log queue is full, dropping log message")
+
 
-            self.event_client.log(message=line, step_run_id=step_run_id)
-        except Exception as e:
-            logger.error(f"Error logging: {e}")
+class CustomLogHandler(logging.StreamHandler):  # type: ignore[type-arg]
+    def __init__(self, log_sender: AsyncLogSender, stream: StringIO):
+        super().__init__(stream)
+
+        self.log_sender = log_sender
 
     def emit(self, record: logging.LogRecord) -> None:
         super().emit(record)
 
         log_entry = self.format(record)
+        step_run_id = ctx_step_run_id.get()
+
+        if not step_run_id:
+            return
 
-        ## TODO: Change how we do this to not assign to the log record
-        self.logger_thread_pool.submit(self._log, log_entry, record.step_run_id)  # type: ignore
+        self.log_sender.publish(LogRecord(message=log_entry, step_run_id=step_run_id))
 
 
 def capture_logs(
-    logger: logging.Logger, event_client: "EventClient", func: Callable[P, Awaitable[T]]
+    logger: logging.Logger, log_sender: AsyncLogSender, func: Callable[P, Awaitable[T]]
 ) -> Callable[P, Awaitable[T]]:
     @functools.wraps(func)
     async def wrapper(*args: P.args, **kwargs: P.kwargs) -> T:
         log_stream = StringIO()
-        custom_handler = CustomLogHandler(event_client, log_stream)
+        custom_handler = CustomLogHandler(log_sender, log_stream)
         custom_handler.setLevel(logging.INFO)
-        custom_handler.addFilter(InjectingFilter())
-        logger.addHandler(custom_handler)
+
+        if not any(h for h in logger.handlers if isinstance(h, CustomLogHandler)):
+            logger.addHandler(custom_handler)
 
         try:
             result = await func(*args, **kwargs)
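
Taken together, `AsyncLogSender` replaces the old per-handler `ThreadPoolExecutor` with a bounded producer/consumer queue: `publish` enqueues without blocking (dropping records on overflow), while a single `consume` loop forwards them to the event client until it sees the `STOP_LOOP` sentinel. A rough lifecycle sketch, assuming the worker wires it up something like this (`event_client` is a stand-in for a real `EventClient`):

    import asyncio

    async def main() -> None:
        sender = AsyncLogSender(event_client)
        consumer = asyncio.create_task(sender.consume())  # background forwarding loop

        # Producers (e.g. CustomLogHandler.emit) enqueue without awaiting:
        sender.publish(LogRecord(message="hello", step_run_id="step-123"))

        # On shutdown, push the sentinel and wait for the loop to exit;
        # records queued before STOP_LOOP are forwarded first (FIFO).
        sender.publish(STOP_LOOP)
        await consumer
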