prefect-client 3.0.0rc13__py3-none-any.whl → 3.0.0rc14__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
prefect/task_engine.py CHANGED
@@ -3,7 +3,7 @@ import logging
3
3
  import threading
4
4
  import time
5
5
  from asyncio import CancelledError
6
- from contextlib import ExitStack, contextmanager
6
+ from contextlib import ExitStack, asynccontextmanager, contextmanager
7
7
  from dataclasses import dataclass, field
8
8
  from functools import wraps
9
9
  from textwrap import dedent
@@ -31,12 +31,16 @@ import pendulum
31
31
  from typing_extensions import ParamSpec
32
32
 
33
33
  from prefect import Task
34
- from prefect.client.orchestration import SyncPrefectClient
34
+ from prefect.client.orchestration import PrefectClient, SyncPrefectClient, get_client
35
35
  from prefect.client.schemas import TaskRun
36
36
  from prefect.client.schemas.objects import State, TaskRunInput
37
+ from prefect.concurrency.asyncio import concurrency as aconcurrency
38
+ from prefect.concurrency.context import ConcurrencyContext
39
+ from prefect.concurrency.sync import concurrency
37
40
  from prefect.context import (
38
- ClientContext,
41
+ AsyncClientContext,
39
42
  FlowRunContext,
43
+ SyncClientContext,
40
44
  TaskRunContext,
41
45
  hydrated_context,
42
46
  )
@@ -59,6 +63,7 @@ from prefect.settings import (
59
63
  )
60
64
  from prefect.states import (
61
65
  AwaitingRetry,
66
+ Completed,
62
67
  Failed,
63
68
  Paused,
64
69
  Pending,
@@ -77,6 +82,7 @@ from prefect.utilities.engine import (
77
82
  _get_hook_name,
78
83
  emit_task_run_state_change_event,
79
84
  link_state_to_result,
85
+ propose_state,
80
86
  propose_state_sync,
81
87
  resolve_to_final_result,
82
88
  )
@@ -86,47 +92,745 @@ from prefect.utilities.timeout import timeout, timeout_async
86
92
# Type parameters for the task's call signature (P) and return value (R),
# used as `Task[P, R]` by the engine classes.
P = ParamSpec("P")
R = TypeVar("R")

# Upper bound for the backoff counter in `SyncTaskRunEngine.begin_run`: while
# the run is Pending/Paused the counter is passed as `average_interval` to
# `clamped_poisson_interval` and the result is slept on, so once the counter
# reaches this value the average wait between polls stops growing.
BACKOFF_MAX = 10
96
+
89
97
 
90
98
class TaskRunTimeoutError(TimeoutError):
    """Error signalling that a task run went past its configured timeout.

    The engine hands this class to its timeout helpers as
    ``timeout_exc_type`` so a task-level timeout can be told apart from any
    other ``TimeoutError`` raised by user code.
    """
92
100
 
93
101
 
94
- @dataclass
95
- class TaskRunEngine(Generic[P, R]):
96
- task: Union[Task[P, R], Task[P, Coroutine[Any, Any, R]]]
97
- logger: logging.Logger = field(default_factory=lambda: get_logger("engine"))
98
- parameters: Optional[Dict[str, Any]] = None
99
- task_run: Optional[TaskRun] = None
100
- retries: int = 0
101
- wait_for: Optional[Iterable[PrefectFuture]] = None
102
- context: Optional[Dict[str, Any]] = None
103
- # holds the return value from the user code
104
- _return_value: Union[R, Type[NotSet]] = NotSet
105
- # holds the exception raised by the user code, if any
106
- _raised: Union[Exception, Type[NotSet]] = NotSet
107
- _initial_run_context: Optional[TaskRunContext] = None
108
- _is_started: bool = False
109
- _client: Optional[SyncPrefectClient] = None
110
- _task_name_set: bool = False
111
- _last_event: Optional[PrefectEvent] = None
102
@dataclass
class BaseTaskRunEngine(Generic[P, R]):
    """State and helpers shared by the sync and async task run engines.

    Holds the run's inputs (task, parameters, upstream futures, context) and
    bookkeeping (return value / raised exception, last emitted event), and
    implements the parts of the run that do not depend on the client flavor:
    parameter and dependency resolution, transaction-key computation, terminal
    timing, the finished-run log message and the rollback event. Subclasses
    supply the client and the state-transition methods.
    """

    task: Union[Task[P, R], Task[P, Coroutine[Any, Any, R]]]
    logger: logging.Logger = field(default_factory=lambda: get_logger("engine"))
    parameters: Optional[Dict[str, Any]] = None
    task_run: Optional[TaskRun] = None
    # number of retries already attempted; compared against `task.retries`
    retries: int = 0
    wait_for: Optional[Iterable[PrefectFuture]] = None
    # context dict handed to `hydrated_context` by the engines; `is_cancelled`
    # looks for a `threading.Event` under the "cancel_event" key
    context: Optional[Dict[str, Any]] = None
    # holds the return value from the user code
    _return_value: Union[R, Type[NotSet]] = NotSet
    # holds the exception raised by the user code, if any
    _raised: Union[Exception, Type[NotSet]] = NotSet
    _initial_run_context: Optional[TaskRunContext] = None
    _is_started: bool = False
    _task_name_set: bool = False
    # last emitted state-change event, so the next event can `follow` it
    _last_event: Optional[PrefectEvent] = None

    def __post_init__(self):
        """Normalize a missing parameters mapping to an empty dict."""
        if self.parameters is None:
            self.parameters = {}

    @property
    def state(self) -> State:
        """The current state of the task run.

        Raises:
            ValueError: if no task run has been set yet.
        """
        if not self.task_run:
            raise ValueError("Task run is not set")
        return self.task_run.state

    def is_cancelled(self) -> bool:
        """Return True if the context carries a `cancel_event` that is set."""
        if (
            self.context
            and "cancel_event" in self.context
            and isinstance(self.context["cancel_event"], threading.Event)
        ):
            return self.context["cancel_event"].is_set()
        return False

    def compute_transaction_key(self) -> Optional[str]:
        """Compute the key identifying this run's transaction record.

        Uses the task's cache policy when one is set (with the task inputs and,
        if inside a flow run, the flow's parameters); otherwise falls back to
        the formatted `result_storage_key`. Returns None when neither applies.
        """
        key = None
        if self.task.cache_policy:
            flow_run_context = FlowRunContext.get()
            task_run_context = TaskRunContext.get()

            if flow_run_context:
                parameters = flow_run_context.parameters
            else:
                parameters = None

            key = self.task.cache_policy.compute_key(
                task_ctx=task_run_context,
                inputs=self.parameters,
                flow_parameters=parameters,
            )
        elif self.task.result_storage_key is not None:
            key = _format_user_supplied_storage_key(self.task.result_storage_key)
        return key

    def _resolve_parameters(self):
        """Replace futures/states in the parameters with their final results.

        Raises:
            UpstreamTaskError: re-raised unchanged so callers can mark the run
                as not ready.
            PrefectException: if any other error occurs while resolving a
                parameter.
        """
        if not self.parameters:
            return {}

        resolved_parameters = {}
        for parameter, value in self.parameters.items():
            try:
                resolved_parameters[parameter] = visit_collection(
                    value,
                    visit_fn=resolve_to_final_result,
                    return_data=True,
                    max_depth=-1,
                    remove_annotations=True,
                    context={},
                )
            except UpstreamTaskError:
                raise
            except Exception as exc:
                raise PrefectException(
                    f"Failed to resolve inputs in parameter {parameter!r}. If your"
                    " parameter type is not supported, consider using the `quote`"
                    " annotation to skip resolution of inputs."
                ) from exc

        self.parameters = resolved_parameters

    def _wait_for_dependencies(self):
        """Block until every object in `wait_for` has resolved to a final result."""
        if not self.wait_for:
            return

        visit_collection(
            self.wait_for,
            visit_fn=resolve_to_final_result,
            return_data=False,
            max_depth=-1,
            remove_annotations=True,
            context={"current_task_run": self.task_run, "current_task": self.task},
        )

    def record_terminal_state_timing(self, state: State) -> None:
        """Record end time and accumulated run time for a terminal `state`.

        Only applies with client-side task orchestration enabled, where the
        client (not the API) maintains these task run fields.
        """
        if PREFECT_EXPERIMENTAL_ENABLE_CLIENT_SIDE_TASK_ORCHESTRATION:
            if self.task_run.start_time and not self.task_run.end_time:
                self.task_run.end_time = state.timestamp

                if self.task_run.state.is_running():
                    self.task_run.total_run_time += (
                        state.timestamp - self.task_run.state.timestamp
                    )

    def is_running(self) -> bool:
        """Whether or not the engine is currently running a task."""
        if (task_run := getattr(self, "task_run", None)) is None:
            return False
        return task_run.state.is_running() or task_run.state.is_scheduled()

    def log_finished_message(self):
        """Log the run's final state (INFO if completed, ERROR otherwise).

        A run still Pending gets an extra hint about waiting on submitted
        futures, with a small runnable example.
        """
        # If debugging, use the more complete `repr` than the usual `str` description
        display_state = repr(self.state) if PREFECT_DEBUG_MODE else str(self.state)
        level = logging.INFO if self.state.is_completed() else logging.ERROR
        msg = f"Finished in state {display_state}"
        if self.state.is_pending():
            msg += (
                "\nPlease wait for all submitted tasks to complete"
                " before exiting your flow by calling `.wait()` on the "
                "`PrefectFuture` returned from your `.submit()` calls."
            )
            # The example must be valid Python 3: the previous text used a
            # Python 2 `print` statement and an unterminated string literal.
            msg += dedent(
                """

                Example:

                from prefect import flow, task

                @task
                def say_hello(name):
                    print(f"Hello, {name}!")

                @flow
                def example_flow():
                    future = say_hello.submit(name="Marvin")
                    future.wait()

                example_flow()
                """
            )
        self.logger.log(
            level=level,
            msg=msg,
        )

    def handle_rollback(self, txn: Transaction) -> None:
        """Emit a "RolledBack" state-change event when the transaction rolls back.

        Only an event is emitted; the task run's stored state is not changed.
        """
        assert self.task_run is not None

        rolled_back_state = Completed(
            name="RolledBack",
            message="Task rolled back as part of transaction",
        )

        self._last_event = emit_task_run_state_change_event(
            task_run=self.task_run,
            initial_state=self.state,
            validated_state=rolled_back_state,
            follows=self._last_event,
        )
263
+
264
+
265
@dataclass
class SyncTaskRunEngine(BaseTaskRunEngine[P, R]):
    """Drives a single task run synchronously.

    Talks to the API through a `SyncPrefectClient` and runs any coroutines it
    meets along the way (async retry conditions, async hooks, result/state
    helpers) with `run_coro_as_sync`. `start`, `transaction_context`,
    `run_context` and `call_task_fn` are the pieces a caller composes to
    execute a run (the composing caller is outside this view).
    """

    _client: Optional[SyncPrefectClient] = None

    @property
    def client(self) -> SyncPrefectClient:
        """The client bound by `initialize_run`.

        Raises:
            RuntimeError: if the engine has not been started.
        """
        if not self._is_started or self._client is None:
            raise RuntimeError("Engine has not started.")
        return self._client

    def can_retry(self, exc: Exception) -> bool:
        """Evaluate the task's `retry_condition_fn` for the exception `exc`.

        Plain and coroutine functions are called with (task, task_run, a
        Failed state wrapping `exc`); otherwise the result is
        `not retry_condition`, i.e. True when no condition is configured.
        Errors raised by the condition are logged and treated as "no retry".

        Raises:
            ValueError: if the task run is not set.
        """
        retry_condition: Optional[
            Callable[[Task[P, Coroutine[Any, Any, R]], TaskRun, State], bool]
        ] = self.task.retry_condition_fn
        if not self.task_run:
            raise ValueError("Task run is not set")
        try:
            self.logger.debug(
                f"Running `retry_condition_fn` check {retry_condition!r} for task"
                f" {self.task.name!r}"
            )
            state = Failed(
                data=exc,
                message=f"Task run encountered unexpected exception: {repr(exc)}",
            )
            if inspect.iscoroutinefunction(retry_condition):
                should_retry = run_coro_as_sync(
                    retry_condition(self.task, self.task_run, state)
                )
            elif inspect.isfunction(retry_condition):
                should_retry = retry_condition(self.task, self.task_run, state)
            else:
                should_retry = not retry_condition
            return should_retry
        except Exception:
            self.logger.error(
                (
                    "An error was encountered while running `retry_condition_fn` check"
                    f" '{retry_condition!r}' for task {self.task.name!r}"
                ),
                exc_info=True,
            )
            return False

    def call_hooks(self, state: Optional[State] = None):
        """Run the task's failure or completion hooks for `state`.

        Defaults to the current state. Awaitable hook results are run to
        completion; hook errors are logged and do not propagate.
        """
        if state is None:
            state = self.state
        task = self.task
        task_run = self.task_run

        if not task_run:
            raise ValueError("Task run is not set")

        if state.is_failed() and task.on_failure_hooks:
            hooks = task.on_failure_hooks
        elif state.is_completed() and task.on_completion_hooks:
            hooks = task.on_completion_hooks
        else:
            hooks = None

        for hook in hooks or []:
            hook_name = _get_hook_name(hook)

            try:
                self.logger.info(
                    f"Running hook {hook_name!r} in response to entering state"
                    f" {state.name!r}"
                )
                result = hook(task, task_run, state)
                if inspect.isawaitable(result):
                    run_coro_as_sync(result)
            except Exception:
                self.logger.error(
                    f"An error was encountered while running hook {hook_name!r}",
                    exc_info=True,
                )
            else:
                self.logger.info(f"Hook {hook_name!r} finished running successfully")

    def begin_run(self):
        """Resolve inputs and move the run into a Running state.

        If upstream inputs are not ready the run is set to Pending("NotReady")
        and this returns early. Otherwise a Running state is proposed, and
        re-proposed with growing (capped) backoff while the run stays Pending
        or Paused.
        """
        try:
            self._resolve_parameters()
            self._wait_for_dependencies()
        except UpstreamTaskError as upstream_exc:
            self.set_state(
                Pending(
                    name="NotReady",
                    message=str(upstream_exc),
                ),
                # if orchestrating a run already in a pending state, force orchestration to
                # update the state name
                force=self.state.is_pending(),
            )
            return

        new_state = Running()

        if PREFECT_EXPERIMENTAL_ENABLE_CLIENT_SIDE_TASK_ORCHESTRATION:
            self.task_run.start_time = new_state.timestamp
            self.task_run.run_count += 1

            flow_run_context = FlowRunContext.get()
            if flow_run_context:
                # Carry forward any task run information from the flow run
                flow_run = flow_run_context.flow_run
                self.task_run.flow_run_run_count = flow_run.run_count

        state = self.set_state(new_state)

        # TODO: this is temporary until the API stops rejecting state transitions
        # and the client / transaction store becomes the source of truth
        # this is a bandaid caused by the API storing a Completed state with a bad
        # result reference that no longer exists
        if state.is_completed():
            try:
                state.result(retry_result_failure=False, _sync=True)
            except Exception:
                state = self.set_state(new_state, force=True)

        backoff_count = 0

        # TODO: Could this listen for state change events instead of polling?
        while state.is_pending() or state.is_paused():
            if backoff_count < BACKOFF_MAX:
                backoff_count += 1
            interval = clamped_poisson_interval(
                average_interval=backoff_count, clamping_factor=0.3
            )
            time.sleep(interval)
            state = self.set_state(new_state)

    def set_state(self, state: State, force: bool = False) -> State:
        """Transition the task run to `state` and emit a state-change event.

        With client-side orchestration the state is applied locally; otherwise
        it is proposed to the API, which may return a different state.

        Returns:
            The state the run actually ended up in.

        Raises:
            ValueError: if the task run is not set.
            Pause: re-raised when the pause asks for a reschedule.
        """
        # Check before reading `self.state`, which would raise the same error.
        if not self.task_run:
            raise ValueError("Task run is not set")
        last_state = self.state

        if PREFECT_EXPERIMENTAL_ENABLE_CLIENT_SIDE_TASK_ORCHESTRATION:
            self.task_run.state = new_state = state

            # Ensure that the state_details are populated with the current run IDs
            new_state.state_details.task_run_id = self.task_run.id
            new_state.state_details.flow_run_id = self.task_run.flow_run_id

            # Predictively update the de-normalized task_run.state_* attributes
            self.task_run.state_id = new_state.id
            self.task_run.state_type = new_state.type
            self.task_run.state_name = new_state.name
        else:
            try:
                new_state = propose_state_sync(
                    self.client, state, task_run_id=self.task_run.id, force=force
                )
            except Pause as exc:
                # We shouldn't get a pause signal without a state, but if this happens,
                # just use a Paused state to assume an in-process pause.
                new_state = exc.state if exc.state else Paused()
                if new_state.state_details.pause_reschedule:
                    # If we're being asked to pause and reschedule, we should exit the
                    # task and expect to be resumed later.
                    raise

            # currently this is a hack to keep a reference to the state object
            # that has an in-memory result attached to it; using the API state
            # could result in losing that reference
            self.task_run.state = new_state

        # emit a state change event
        self._last_event = emit_task_run_state_change_event(
            task_run=self.task_run,
            initial_state=last_state,
            validated_state=self.task_run.state,
            follows=self._last_event,
        )

        return new_state

    def result(self, raise_on_failure: bool = True) -> "Union[R, State, None]":
        """Return the run's result, or its exception.

        A stored `BaseResult` is fetched first. If the task raised, the
        exception is re-raised unless `raise_on_failure` is False, in which
        case it is returned. Returns None if neither is set.
        """
        if self._return_value is not NotSet:
            # if the return value is a BaseResult, we need to fetch it
            if isinstance(self._return_value, BaseResult):
                _result = self._return_value.get()
                if inspect.isawaitable(_result):
                    _result = run_coro_as_sync(_result)
                return _result

            # otherwise, return the value as is
            return self._return_value

        if self._raised is not NotSet:
            # if the task raised an exception, raise it
            if raise_on_failure:
                raise self._raised

            # otherwise, return the exception
            return self._raised

    def handle_success(self, result: R, transaction: Transaction) -> R:
        """Stage `result` in the transaction and move the run to its terminal state.

        Persistence is deferred to transaction commit; rollback/commit hooks
        are registered on the staged value. A run whose transaction is already
        committed gets its terminal state named "Cached".

        Raises:
            ValueError: if no result factory is available on the run context.
        """
        result_factory = getattr(TaskRunContext.get(), "result_factory", None)
        if result_factory is None:
            raise ValueError("Result factory is not set")

        if self.task.cache_expiration is not None:
            expiration = pendulum.now("utc") + self.task.cache_expiration
        else:
            expiration = None

        terminal_state = run_coro_as_sync(
            return_value_to_state(
                result,
                result_factory=result_factory,
                key=transaction.key,
                expiration=expiration,
                # defer persistence to transaction commit
                defer_persistence=True,
            )
        )
        transaction.stage(
            terminal_state.data,
            on_rollback_hooks=[self.handle_rollback]
            + [
                _with_transaction_hook_logging(hook, "rollback", self.logger)
                for hook in self.task.on_rollback_hooks
            ],
            on_commit_hooks=[
                _with_transaction_hook_logging(hook, "commit", self.logger)
                for hook in self.task.on_commit_hooks
            ],
        )
        if transaction.is_committed():
            terminal_state.name = "Cached"

        self.record_terminal_state_timing(terminal_state)
        self.set_state(terminal_state)
        self._return_value = result
        return result

    def handle_retry(self, exc: Exception) -> bool:
        """Handle any task run retries.

        - If the task has retries left, and the retry condition is met, set the task to retrying and return True.
        - If the task has a retry delay, place in AwaitingRetry state with a delayed scheduled time.
        - If the task has no retries left, or the retry condition is not met, return False.
        """
        if self.retries < self.task.retries and self.can_retry(exc):
            if self.task.retry_delay_seconds:
                delay = (
                    self.task.retry_delay_seconds[
                        min(self.retries, len(self.task.retry_delay_seconds) - 1)
                    ]  # repeat final delay value if attempts exceed specified delays
                    if isinstance(self.task.retry_delay_seconds, Sequence)
                    else self.task.retry_delay_seconds
                )
                new_state = AwaitingRetry(
                    scheduled_time=pendulum.now("utc").add(seconds=delay)
                )
            else:
                delay = None
                new_state = Retrying()
                if PREFECT_EXPERIMENTAL_ENABLE_CLIENT_SIDE_TASK_ORCHESTRATION:
                    self.task_run.run_count += 1

            self.logger.info(
                "Task run failed with exception: %r - " "Retry %s/%s will start %s",
                exc,
                self.retries + 1,
                self.task.retries,
                str(delay) + " second(s) from now" if delay else "immediately",
            )

            self.set_state(new_state, force=True)
            self.retries = self.retries + 1
            return True
        elif self.retries >= self.task.retries:
            self.logger.error(
                "Task run failed with exception: %r - Retries are exhausted",
                exc,
                exc_info=True,
            )
            return False

        return False

    def handle_exception(self, exc: Exception) -> None:
        """Retry on `exc` if possible; otherwise move the run to Failed."""
        # If the task fails, and we have retries left, set the task to retrying.
        if not self.handle_retry(exc):
            # If the task has no retries left, or the retry condition is not met, set the task to failed.
            context = TaskRunContext.get()
            state = run_coro_as_sync(
                exception_to_failed_state(
                    exc,
                    message="Task run encountered an exception",
                    result_factory=getattr(context, "result_factory", None),
                )
            )
            self.record_terminal_state_timing(state)
            self.set_state(state)
            self._raised = exc

    def handle_timeout(self, exc: TimeoutError) -> None:
        """Retry on timeout if possible; otherwise move the run to Failed("TimedOut")."""
        if not self.handle_retry(exc):
            if isinstance(exc, TaskRunTimeoutError):
                message = f"Task run exceeded timeout of {self.task.timeout_seconds} second(s)"
            else:
                message = f"Task run failed due to timeout: {exc!r}"
            self.logger.error(message)
            state = Failed(
                data=exc,
                message=message,
                name="TimedOut",
            )
            # Like the other terminal handlers, record end/total run time before
            # transitioning; this was previously skipped for timeouts.
            self.record_terminal_state_timing(state)
            self.set_state(state)
            self._raised = exc

    def handle_crash(self, exc: BaseException) -> None:
        """Force the run into a Crashed state built from `exc`."""
        state = run_coro_as_sync(exception_to_crashed_state(exc))
        self.logger.error(f"Crash detected! {state.message}")
        self.logger.debug("Crash details:", exc_info=exc)
        self.record_terminal_state_timing(state)
        self.set_state(state, force=True)
        self._raised = exc

    @contextmanager
    def setup_run_context(self, client: Optional[SyncPrefectClient] = None):
        """Enter a fresh `TaskRunContext` (and related contexts) for this run.

        Refreshes the task run from the API (unless orchestrating client-side),
        optionally patches `print` for log capture, rebinds the run logger and,
        once per engine, applies a custom task run name.

        Args:
            client: client to use; defaults to the engine's own client.
        """
        from prefect.utilities.engine import (
            _resolve_custom_task_run_name,
            should_log_prints,
        )

        if client is None:
            client = self.client
        if not self.task_run:
            raise ValueError("Task run is not set")

        if not PREFECT_EXPERIMENTAL_ENABLE_CLIENT_SIDE_TASK_ORCHESTRATION:
            self.task_run = client.read_task_run(self.task_run.id)
        with ExitStack() as stack:
            if log_prints := should_log_prints(self.task):
                stack.enter_context(patch_print())
            stack.enter_context(
                TaskRunContext(
                    task=self.task,
                    log_prints=log_prints,
                    task_run=self.task_run,
                    parameters=self.parameters,
                    result_factory=run_coro_as_sync(ResultFactory.from_task(self.task)),  # type: ignore
                    client=client,
                )
            )
            stack.enter_context(ConcurrencyContext())

            self.logger = task_run_logger(task_run=self.task_run, task=self.task)  # type: ignore

            if not PREFECT_EXPERIMENTAL_ENABLE_CLIENT_SIDE_TASK_ORCHESTRATION:
                # update the task run name if necessary
                if not self._task_name_set and self.task.task_run_name:
                    task_run_name = _resolve_custom_task_run_name(
                        task=self.task, parameters=self.parameters
                    )
                    # Use the same client that read the task run above, so a
                    # caller-supplied client is honored consistently.
                    client.set_task_run_name(
                        task_run_id=self.task_run.id, name=task_run_name
                    )
                    self.logger.extra["task_run_name"] = task_run_name
                    self.logger.debug(
                        f"Renamed task run {self.task_run.name!r} to {task_run_name!r}"
                    )
                    self.task_run.name = task_run_name
                    self._task_name_set = True
            yield

    @contextmanager
    def initialize_run(
        self,
        task_run_id: Optional[UUID] = None,
        dependencies: Optional[Dict[str, Set[TaskRunInput]]] = None,
    ) -> Generator["SyncTaskRunEngine", Any, Any]:
        """
        Enters a client context and creates a task run if needed.

        Termination signals and other BaseExceptions (except Pause, Abort and
        GeneratorExit) are recorded as crashes and re-raised; the finished
        message is always logged and the client is released on exit.
        """

        with hydrated_context(self.context):
            with SyncClientContext.get_or_create() as client_ctx:
                self._client = client_ctx.client
                self._is_started = True
                try:
                    if PREFECT_EXPERIMENTAL_ENABLE_CLIENT_SIDE_TASK_ORCHESTRATION:
                        from prefect.utilities.engine import (
                            _resolve_custom_task_run_name,
                        )

                        task_run_name = (
                            _resolve_custom_task_run_name(
                                task=self.task, parameters=self.parameters
                            )
                            if self.task.task_run_name
                            else None
                        )

                        if self.task_run and task_run_name:
                            self.task_run.name = task_run_name

                        if not self.task_run:
                            self.task_run = run_coro_as_sync(
                                self.task.create_local_run(
                                    id=task_run_id,
                                    parameters=self.parameters,
                                    flow_run_context=FlowRunContext.get(),
                                    parent_task_run_context=TaskRunContext.get(),
                                    wait_for=self.wait_for,
                                    extra_task_inputs=dependencies,
                                    task_run_name=task_run_name,
                                )
                            )
                    else:
                        if not self.task_run:
                            self.task_run = run_coro_as_sync(
                                self.task.create_run(
                                    id=task_run_id,
                                    parameters=self.parameters,
                                    flow_run_context=FlowRunContext.get(),
                                    parent_task_run_context=TaskRunContext.get(),
                                    wait_for=self.wait_for,
                                    extra_task_inputs=dependencies,
                                )
                            )
                    # Emit an event to capture that the task run was in the `PENDING` state.
                    self._last_event = emit_task_run_state_change_event(
                        task_run=self.task_run,
                        initial_state=None,
                        validated_state=self.task_run.state,
                    )

                    with self.setup_run_context():
                        # setup_run_context might update the task run name, so log creation here
                        self.logger.info(
                            f"Created task run {self.task_run.name!r} for task {self.task.name!r}"
                        )
                        yield self

                except TerminationSignal as exc:
                    # TerminationSignals are caught and handled as crashes
                    self.handle_crash(exc)
                    raise exc

                except Exception:
                    # regular exceptions are caught and re-raised to the user
                    raise
                except (Pause, Abort) as exc:
                    # Do not capture internal signals as crashes
                    if isinstance(exc, Abort):
                        self.logger.error("Task run was aborted: %s", exc)
                    raise
                except GeneratorExit:
                    # Do not capture generator exits as crashes
                    raise
                except BaseException as exc:
                    # BaseExceptions are caught and handled as crashes
                    self.handle_crash(exc)
                    raise
                finally:
                    self.log_finished_message()
                    self._is_started = False
                    self._client = None

    async def wait_until_ready(self):
        """Waits until the scheduled time (if its the future), then enters Running."""
        # NOTE(review): this is a coroutine on the *sync* engine and calls the
        # sync `set_state`; kept async so existing awaiting callers still work.
        if scheduled_time := self.state.state_details.scheduled_time:
            sleep_time = (scheduled_time - pendulum.now("utc")).total_seconds()
            await anyio.sleep(sleep_time if sleep_time > 0 else 0)
        self.set_state(
            Retrying() if self.state.name == "AwaitingRetry" else Running(),
            force=True,
        )

    # --------------------------
    #
    # The following methods compose the main task run loop
    #
    # --------------------------

    @contextmanager
    def start(
        self,
        task_run_id: Optional[UUID] = None,
        dependencies: Optional[Dict[str, Set[TaskRunInput]]] = None,
    ) -> Generator[None, None, None]:
        """Initialize the run, begin it, and run state hooks on exit."""
        with self.initialize_run(task_run_id=task_run_id, dependencies=dependencies):
            self.begin_run()
            try:
                yield
            finally:
                self.call_hooks()

    @contextmanager
    def transaction_context(self) -> Generator[Transaction, None, None]:
        """Open the transaction that stores (or already holds) this run's result."""
        result_factory = getattr(TaskRunContext.get(), "result_factory", None)

        # refresh cache setting is now repurposes as overwrite transaction record
        overwrite = (
            self.task.refresh_cache
            if self.task.refresh_cache is not None
            else PREFECT_TASKS_REFRESH_CACHE.value()
        )
        with transaction(
            key=self.compute_transaction_key(),
            store=ResultFactoryStore(result_factory=result_factory),
            overwrite=overwrite,
            logger=self.logger,
        ) as txn:
            yield txn

    @contextmanager
    def run_context(self):
        """Run one attempt under a timeout, routing errors to the handlers.

        Timeouts go to `handle_timeout` and other exceptions to
        `handle_exception`; neither is re-raised from here.
        """
        # reenter the run context to ensure it is up to date for every run
        with self.setup_run_context():
            try:
                with timeout(
                    seconds=self.task.timeout_seconds,
                    timeout_exc_type=TaskRunTimeoutError,
                ):
                    self.logger.debug(
                        f"Executing task {self.task.name!r} for task run {self.task_run.name!r}..."
                    )
                    if self.is_cancelled():
                        raise CancelledError("Task run cancelled by the task runner")

                    yield self
            except TimeoutError as exc:
                self.handle_timeout(exc)
            except Exception as exc:
                self.handle_exception(exc)

    def call_task_fn(
        self, transaction: Transaction
    ) -> Union[R, Coroutine[Any, Any, R]]:
        """
        Convenience method to call the task function. Returns a coroutine if the
        task is async.

        A committed transaction short-circuits the call and its stored value is
        used instead. With client-side orchestration and task tags, a
        concurrency slot is held per tag (only for limits that already exist).
        """
        parameters = self.parameters or {}
        if transaction.is_committed():
            result = transaction.read()
        else:
            if (
                PREFECT_EXPERIMENTAL_ENABLE_CLIENT_SIDE_TASK_ORCHESTRATION.value()
                and self.task.tags
            ):
                # Acquire a concurrency slot for each tag, but only if a limit
                # matching the tag already exists.
                with concurrency(
                    list(self.task.tags), occupy=1, create_if_missing=False
                ):
                    result = call_with_parameters(self.task.fn, parameters)
            else:
                result = call_with_parameters(self.task.fn, parameters)
        self.handle_success(result, transaction=transaction)
        return result
112
821
 
113
- def __post_init__(self):
114
- if self.parameters is None:
115
- self.parameters = {}
822
+
823
+ @dataclass
824
+ class AsyncTaskRunEngine(BaseTaskRunEngine[P, R]):
825
+ _client: Optional[PrefectClient] = None
116
826
 
117
827
  @property
118
- def client(self) -> SyncPrefectClient:
828
+ def client(self) -> PrefectClient:
119
829
  if not self._is_started or self._client is None:
120
830
  raise RuntimeError("Engine has not started.")
121
831
  return self._client
122
832
 
123
- @property
124
- def state(self) -> State:
125
- if not self.task_run:
126
- raise ValueError("Task run is not set")
127
- return self.task_run.state
128
-
129
- def can_retry(self, exc: Exception) -> bool:
833
+ async def can_retry(self, exc: Exception) -> bool:
130
834
  retry_condition: Optional[
131
835
  Callable[[Task[P, Coroutine[Any, Any, R]], TaskRun, State], bool]
132
836
  ] = self.task.retry_condition_fn
@@ -142,14 +846,13 @@ class TaskRunEngine(Generic[P, R]):
142
846
  message=f"Task run encountered unexpected exception: {repr(exc)}",
143
847
  )
144
848
  if inspect.iscoroutinefunction(retry_condition):
145
- should_retry = run_coro_as_sync(
146
- retry_condition(self.task, self.task_run, state)
147
- )
849
+ should_retry = await retry_condition(self.task, self.task_run, state)
148
850
  elif inspect.isfunction(retry_condition):
149
851
  should_retry = retry_condition(self.task, self.task_run, state)
150
852
  else:
151
853
  should_retry = not retry_condition
152
854
  return should_retry
855
+
153
856
  except Exception:
154
857
  self.logger.error(
155
858
  (
@@ -160,16 +863,7 @@ class TaskRunEngine(Generic[P, R]):
160
863
  )
161
864
  return False
162
865
 
163
- def is_cancelled(self) -> bool:
164
- if (
165
- self.context
166
- and "cancel_event" in self.context
167
- and isinstance(self.context["cancel_event"], threading.Event)
168
- ):
169
- return self.context["cancel_event"].is_set()
170
- return False
171
-
172
- def call_hooks(self, state: Optional[State] = None):
866
+ async def call_hooks(self, state: Optional[State] = None):
173
867
  if state is None:
174
868
  state = self.state
175
869
  task = self.task
@@ -195,7 +889,7 @@ class TaskRunEngine(Generic[P, R]):
195
889
  )
196
890
  result = hook(task, task_run, state)
197
891
  if inspect.isawaitable(result):
198
- run_coro_as_sync(result)
892
+ await result
199
893
  except Exception:
200
894
  self.logger.error(
201
895
  f"An error was encountered while running hook {hook_name!r}",
@@ -204,71 +898,12 @@ class TaskRunEngine(Generic[P, R]):
204
898
  else:
205
899
  self.logger.info(f"Hook {hook_name!r} finished running successfully")
206
900
 
207
- def compute_transaction_key(self) -> Optional[str]:
208
- key = None
209
- if self.task.cache_policy:
210
- flow_run_context = FlowRunContext.get()
211
- task_run_context = TaskRunContext.get()
212
-
213
- if flow_run_context:
214
- parameters = flow_run_context.parameters
215
- else:
216
- parameters = None
217
-
218
- key = self.task.cache_policy.compute_key(
219
- task_ctx=task_run_context,
220
- inputs=self.parameters,
221
- flow_parameters=parameters,
222
- )
223
- elif self.task.result_storage_key is not None:
224
- key = _format_user_supplied_storage_key(self.task.result_storage_key)
225
- return key
226
-
227
- def _resolve_parameters(self):
228
- if not self.parameters:
229
- return {}
230
-
231
- resolved_parameters = {}
232
- for parameter, value in self.parameters.items():
233
- try:
234
- resolved_parameters[parameter] = visit_collection(
235
- value,
236
- visit_fn=resolve_to_final_result,
237
- return_data=True,
238
- max_depth=-1,
239
- remove_annotations=True,
240
- context={},
241
- )
242
- except UpstreamTaskError:
243
- raise
244
- except Exception as exc:
245
- raise PrefectException(
246
- f"Failed to resolve inputs in parameter {parameter!r}. If your"
247
- " parameter type is not supported, consider using the `quote`"
248
- " annotation to skip resolution of inputs."
249
- ) from exc
250
-
251
- self.parameters = resolved_parameters
252
-
253
- def _wait_for_dependencies(self):
254
- if not self.wait_for:
255
- return
256
-
257
- visit_collection(
258
- self.wait_for,
259
- visit_fn=resolve_to_final_result,
260
- return_data=False,
261
- max_depth=-1,
262
- remove_annotations=True,
263
- context={"current_task_run": self.task_run, "current_task": self.task},
264
- )
265
-
266
- def begin_run(self):
901
+ async def begin_run(self):
267
902
  try:
268
903
  self._resolve_parameters()
269
904
  self._wait_for_dependencies()
270
905
  except UpstreamTaskError as upstream_exc:
271
- state = self.set_state(
906
+ state = await self.set_state(
272
907
  Pending(
273
908
  name="NotReady",
274
909
  message=str(upstream_exc),
@@ -291,7 +926,7 @@ class TaskRunEngine(Generic[P, R]):
291
926
  flow_run = flow_run_context.flow_run
292
927
  self.task_run.flow_run_run_count = flow_run.run_count
293
928
 
294
- state = self.set_state(new_state)
929
+ state = await self.set_state(new_state)
295
930
 
296
931
  # TODO: this is temporary until the API stops rejecting state transitions
297
932
  # and the client / transaction store becomes the source of truth
@@ -299,11 +934,10 @@ class TaskRunEngine(Generic[P, R]):
299
934
  # result reference that no longer exists
300
935
  if state.is_completed():
301
936
  try:
302
- state.result(retry_result_failure=False, _sync=True)
937
+ await state.result(retry_result_failure=False)
303
938
  except Exception:
304
- state = self.set_state(new_state, force=True)
939
+ state = await self.set_state(new_state, force=True)
305
940
 
306
- BACKOFF_MAX = 10
307
941
  backoff_count = 0
308
942
 
309
943
  # TODO: Could this listen for state change events instead of polling?
@@ -313,10 +947,10 @@ class TaskRunEngine(Generic[P, R]):
313
947
  interval = clamped_poisson_interval(
314
948
  average_interval=backoff_count, clamping_factor=0.3
315
949
  )
316
- time.sleep(interval)
317
- state = self.set_state(new_state)
950
+ await anyio.sleep(interval)
951
+ state = await self.set_state(new_state)
318
952
 
319
- def set_state(self, state: State, force: bool = False) -> State:
953
+ async def set_state(self, state: State, force: bool = False) -> State:
320
954
  last_state = self.state
321
955
  if not self.task_run:
322
956
  raise ValueError("Task run is not set")
@@ -334,7 +968,7 @@ class TaskRunEngine(Generic[P, R]):
334
968
  self.task_run.state_name = new_state.name
335
969
  else:
336
970
  try:
337
- new_state = propose_state_sync(
971
+ new_state = await propose_state(
338
972
  self.client, state, task_run_id=self.task_run.id, force=force
339
973
  )
340
974
  except Pause as exc:
@@ -361,14 +995,11 @@ class TaskRunEngine(Generic[P, R]):
361
995
 
362
996
  return new_state
363
997
 
364
- def result(self, raise_on_failure: bool = True) -> "Union[R, State, None]":
998
+ async def result(self, raise_on_failure: bool = True) -> "Union[R, State, None]":
365
999
  if self._return_value is not NotSet:
366
1000
  # if the return value is a BaseResult, we need to fetch it
367
1001
  if isinstance(self._return_value, BaseResult):
368
- _result = self._return_value.get()
369
- if inspect.isawaitable(_result):
370
- _result = run_coro_as_sync(_result)
371
- return _result
1002
+ return await self._return_value.get()
372
1003
 
373
1004
  # otherwise, return the value as is
374
1005
  return self._return_value
@@ -381,7 +1012,7 @@ class TaskRunEngine(Generic[P, R]):
381
1012
  # otherwise, return the exception
382
1013
  return self._raised
383
1014
 
384
- def handle_success(self, result: R, transaction: Transaction) -> R:
1015
+ async def handle_success(self, result: R, transaction: Transaction) -> R:
385
1016
  result_factory = getattr(TaskRunContext.get(), "result_factory", None)
386
1017
  if result_factory is None:
387
1018
  raise ValueError("Result factory is not set")
@@ -391,19 +1022,18 @@ class TaskRunEngine(Generic[P, R]):
391
1022
  else:
392
1023
  expiration = None
393
1024
 
394
- terminal_state = run_coro_as_sync(
395
- return_value_to_state(
396
- result,
397
- result_factory=result_factory,
398
- key=transaction.key,
399
- expiration=expiration,
400
- # defer persistence to transaction commit
401
- defer_persistence=True,
402
- )
1025
+ terminal_state = await return_value_to_state(
1026
+ result,
1027
+ result_factory=result_factory,
1028
+ key=transaction.key,
1029
+ expiration=expiration,
1030
+ # defer persistence to transaction commit
1031
+ defer_persistence=True,
403
1032
  )
404
1033
  transaction.stage(
405
1034
  terminal_state.data,
406
- on_rollback_hooks=[
1035
+ on_rollback_hooks=[self.handle_rollback]
1036
+ + [
407
1037
  _with_transaction_hook_logging(hook, "rollback", self.logger)
408
1038
  for hook in self.task.on_rollback_hooks
409
1039
  ],
@@ -416,18 +1046,18 @@ class TaskRunEngine(Generic[P, R]):
416
1046
  terminal_state.name = "Cached"
417
1047
 
418
1048
  self.record_terminal_state_timing(terminal_state)
419
- self.set_state(terminal_state)
1049
+ await self.set_state(terminal_state)
420
1050
  self._return_value = result
421
1051
  return result
422
1052
 
423
- def handle_retry(self, exc: Exception) -> bool:
1053
+ async def handle_retry(self, exc: Exception) -> bool:
424
1054
  """Handle any task run retries.
425
1055
 
426
1056
  - If the task has retries left, and the retry condition is met, set the task to retrying and return True.
427
- - If the task has a retry delay, place in AwaitingRetry state with a delayed scheduled time.
1057
+ - If the task has a retry delay, place in AwaitingRetry state with a delayed scheduled time.
428
1058
  - If the task has no retries left, or the retry condition is not met, return False.
429
1059
  """
430
- if self.retries < self.task.retries and self.can_retry(exc):
1060
+ if self.retries < self.task.retries and await self.can_retry(exc):
431
1061
  if self.task.retry_delay_seconds:
432
1062
  delay = (
433
1063
  self.task.retry_delay_seconds[
@@ -453,7 +1083,7 @@ class TaskRunEngine(Generic[P, R]):
453
1083
  str(delay) + " second(s) from now" if delay else "immediately",
454
1084
  )
455
1085
 
456
- self.set_state(new_state, force=True)
1086
+ await self.set_state(new_state, force=True)
457
1087
  self.retries = self.retries + 1
458
1088
  return True
459
1089
  elif self.retries >= self.task.retries:
@@ -466,24 +1096,22 @@ class TaskRunEngine(Generic[P, R]):
466
1096
 
467
1097
  return False
468
1098
 
469
- def handle_exception(self, exc: Exception) -> None:
1099
+ async def handle_exception(self, exc: Exception) -> None:
470
1100
  # If the task fails, and we have retries left, set the task to retrying.
471
- if not self.handle_retry(exc):
1101
+ if not await self.handle_retry(exc):
472
1102
  # If the task has no retries left, or the retry condition is not met, set the task to failed.
473
1103
  context = TaskRunContext.get()
474
- state = run_coro_as_sync(
475
- exception_to_failed_state(
476
- exc,
477
- message="Task run encountered an exception",
478
- result_factory=getattr(context, "result_factory", None),
479
- )
1104
+ state = await exception_to_failed_state(
1105
+ exc,
1106
+ message="Task run encountered an exception",
1107
+ result_factory=getattr(context, "result_factory", None),
480
1108
  )
481
1109
  self.record_terminal_state_timing(state)
482
- self.set_state(state)
1110
+ await self.set_state(state)
483
1111
  self._raised = exc
484
1112
 
485
- def handle_timeout(self, exc: TimeoutError) -> None:
486
- if not self.handle_retry(exc):
1113
+ async def handle_timeout(self, exc: TimeoutError) -> None:
1114
+ if not await self.handle_retry(exc):
487
1115
  if isinstance(exc, TaskRunTimeoutError):
488
1116
  message = f"Task run exceeded timeout of {self.task.timeout_seconds} second(s)"
489
1117
  else:
@@ -494,29 +1122,19 @@ class TaskRunEngine(Generic[P, R]):
494
1122
  message=message,
495
1123
  name="TimedOut",
496
1124
  )
497
- self.set_state(state)
1125
+ await self.set_state(state)
498
1126
  self._raised = exc
499
1127
 
500
- def handle_crash(self, exc: BaseException) -> None:
501
- state = run_coro_as_sync(exception_to_crashed_state(exc))
1128
+ async def handle_crash(self, exc: BaseException) -> None:
1129
+ state = await exception_to_crashed_state(exc)
502
1130
  self.logger.error(f"Crash detected! {state.message}")
503
1131
  self.logger.debug("Crash details:", exc_info=exc)
504
1132
  self.record_terminal_state_timing(state)
505
- self.set_state(state, force=True)
1133
+ await self.set_state(state, force=True)
506
1134
  self._raised = exc
507
1135
 
508
- def record_terminal_state_timing(self, state: State) -> None:
509
- if PREFECT_EXPERIMENTAL_ENABLE_CLIENT_SIDE_TASK_ORCHESTRATION:
510
- if self.task_run.start_time and not self.task_run.end_time:
511
- self.task_run.end_time = state.timestamp
512
-
513
- if self.task_run.state.is_running():
514
- self.task_run.total_run_time += (
515
- state.timestamp - self.task_run.state.timestamp
516
- )
517
-
518
- @contextmanager
519
- def setup_run_context(self, client: Optional[SyncPrefectClient] = None):
1136
+ @asynccontextmanager
1137
+ async def setup_run_context(self, client: Optional[PrefectClient] = None):
520
1138
  from prefect.utilities.engine import (
521
1139
  _resolve_custom_task_run_name,
522
1140
  should_log_prints,
@@ -528,7 +1146,7 @@ class TaskRunEngine(Generic[P, R]):
528
1146
  raise ValueError("Task run is not set")
529
1147
 
530
1148
  if not PREFECT_EXPERIMENTAL_ENABLE_CLIENT_SIDE_TASK_ORCHESTRATION:
531
- self.task_run = client.read_task_run(self.task_run.id)
1149
+ self.task_run = await client.read_task_run(self.task_run.id)
532
1150
  with ExitStack() as stack:
533
1151
  if log_prints := should_log_prints(self.task):
534
1152
  stack.enter_context(patch_print())
@@ -538,10 +1156,11 @@ class TaskRunEngine(Generic[P, R]):
538
1156
  log_prints=log_prints,
539
1157
  task_run=self.task_run,
540
1158
  parameters=self.parameters,
541
- result_factory=run_coro_as_sync(ResultFactory.from_task(self.task)), # type: ignore
1159
+ result_factory=await ResultFactory.from_task(self.task), # type: ignore
542
1160
  client=client,
543
1161
  )
544
1162
  )
1163
+ stack.enter_context(ConcurrencyContext())
545
1164
 
546
1165
  self.logger = task_run_logger(task_run=self.task_run, task=self.task) # type: ignore
547
1166
 
@@ -551,7 +1170,7 @@ class TaskRunEngine(Generic[P, R]):
551
1170
  task_run_name = _resolve_custom_task_run_name(
552
1171
  task=self.task, parameters=self.parameters
553
1172
  )
554
- self.client.set_task_run_name(
1173
+ await self.client.set_task_run_name(
555
1174
  task_run_id=self.task_run.id, name=task_run_name
556
1175
  )
557
1176
  self.logger.extra["task_run_name"] = task_run_name
@@ -562,19 +1181,19 @@ class TaskRunEngine(Generic[P, R]):
562
1181
  self._task_name_set = True
563
1182
  yield
564
1183
 
565
- @contextmanager
566
- def initialize_run(
1184
+ @asynccontextmanager
1185
+ async def initialize_run(
567
1186
  self,
568
1187
  task_run_id: Optional[UUID] = None,
569
1188
  dependencies: Optional[Dict[str, Set[TaskRunInput]]] = None,
570
- ) -> Generator["TaskRunEngine", Any, Any]:
1189
+ ) -> AsyncGenerator["AsyncTaskRunEngine", Any]:
571
1190
  """
572
1191
  Enters a client context and creates a task run if needed.
573
1192
  """
574
1193
 
575
1194
  with hydrated_context(self.context):
576
- with ClientContext.get_or_create() as client_ctx:
577
- self._client = client_ctx.sync_client
1195
+ async with AsyncClientContext.get_or_create():
1196
+ self._client = get_client()
578
1197
  self._is_started = True
579
1198
  try:
580
1199
  if PREFECT_EXPERIMENTAL_ENABLE_CLIENT_SIDE_TASK_ORCHESTRATION:
@@ -594,28 +1213,24 @@ class TaskRunEngine(Generic[P, R]):
594
1213
  self.task_run.name = task_run_name
595
1214
 
596
1215
  if not self.task_run:
597
- self.task_run = run_coro_as_sync(
598
- self.task.create_local_run(
599
- id=task_run_id,
600
- parameters=self.parameters,
601
- flow_run_context=FlowRunContext.get(),
602
- parent_task_run_context=TaskRunContext.get(),
603
- wait_for=self.wait_for,
604
- extra_task_inputs=dependencies,
605
- task_run_name=task_run_name,
606
- )
1216
+ self.task_run = await self.task.create_local_run(
1217
+ id=task_run_id,
1218
+ parameters=self.parameters,
1219
+ flow_run_context=FlowRunContext.get(),
1220
+ parent_task_run_context=TaskRunContext.get(),
1221
+ wait_for=self.wait_for,
1222
+ extra_task_inputs=dependencies,
1223
+ task_run_name=task_run_name,
607
1224
  )
608
1225
  else:
609
1226
  if not self.task_run:
610
- self.task_run = run_coro_as_sync(
611
- self.task.create_run(
612
- id=task_run_id,
613
- parameters=self.parameters,
614
- flow_run_context=FlowRunContext.get(),
615
- parent_task_run_context=TaskRunContext.get(),
616
- wait_for=self.wait_for,
617
- extra_task_inputs=dependencies,
618
- )
1227
+ self.task_run = await self.task.create_run(
1228
+ id=task_run_id,
1229
+ parameters=self.parameters,
1230
+ flow_run_context=FlowRunContext.get(),
1231
+ parent_task_run_context=TaskRunContext.get(),
1232
+ wait_for=self.wait_for,
1233
+ extra_task_inputs=dependencies,
619
1234
  )
620
1235
  # Emit an event to capture that the task run was in the `PENDING` state.
621
1236
  self._last_event = emit_task_run_state_change_event(
@@ -624,7 +1239,7 @@ class TaskRunEngine(Generic[P, R]):
624
1239
  validated_state=self.task_run.state,
625
1240
  )
626
1241
 
627
- with self.setup_run_context():
1242
+ async with self.setup_run_context():
628
1243
  # setup_run_context might update the task run name, so log creation here
629
1244
  self.logger.info(
630
1245
  f"Created task run {self.task_run.name!r} for task {self.task.name!r}"
@@ -633,7 +1248,7 @@ class TaskRunEngine(Generic[P, R]):
633
1248
 
634
1249
  except TerminationSignal as exc:
635
1250
  # TerminationSignals are caught and handled as crashes
636
- self.handle_crash(exc)
1251
+ await self.handle_crash(exc)
637
1252
  raise exc
638
1253
 
639
1254
  except Exception:
@@ -649,60 +1264,19 @@ class TaskRunEngine(Generic[P, R]):
649
1264
  raise
650
1265
  except BaseException as exc:
651
1266
  # BaseExceptions are caught and handled as crashes
652
- self.handle_crash(exc)
1267
+ await self.handle_crash(exc)
653
1268
  raise
654
1269
  finally:
655
- # If debugging, use the more complete `repr` than the usual `str` description
656
- display_state = (
657
- repr(self.state) if PREFECT_DEBUG_MODE else str(self.state)
658
- )
659
- level = logging.INFO if self.state.is_completed() else logging.ERROR
660
- msg = f"Finished in state {display_state}"
661
- if self.state.is_pending():
662
- msg += (
663
- "\nPlease wait for all submitted tasks to complete"
664
- " before exiting your flow by calling `.wait()` on the "
665
- "`PrefectFuture` returned from your `.submit()` calls."
666
- )
667
- msg += dedent(
668
- """
669
-
670
- Example:
671
-
672
- from prefect import flow, task
673
-
674
- @task
675
- def say_hello(name):
676
- print f"Hello, {name}!"
677
-
678
- @flow
679
- def example_flow():
680
- future = say_hello.submit(name="Marvin)
681
- future.wait()
682
-
683
- example_flow()
684
- """
685
- )
686
- self.logger.log(
687
- level=level,
688
- msg=msg,
689
- )
690
-
1270
+ self.log_finished_message()
691
1271
  self._is_started = False
692
1272
  self._client = None
693
1273
 
694
- def is_running(self) -> bool:
695
- """Whether or not the engine is currently running a task."""
696
- if (task_run := getattr(self, "task_run", None)) is None:
697
- return False
698
- return task_run.state.is_running() or task_run.state.is_scheduled()
699
-
700
1274
  async def wait_until_ready(self):
701
1275
  """Waits until the scheduled time (if its the future), then enters Running."""
702
1276
  if scheduled_time := self.state.state_details.scheduled_time:
703
1277
  sleep_time = (scheduled_time - pendulum.now("utc")).total_seconds()
704
1278
  await anyio.sleep(sleep_time if sleep_time > 0 else 0)
705
- self.set_state(
1279
+ await self.set_state(
706
1280
  Retrying() if self.state.name == "AwaitingRetry" else Running(),
707
1281
  force=True,
708
1282
  )
@@ -713,21 +1287,23 @@ class TaskRunEngine(Generic[P, R]):
713
1287
  #
714
1288
  # --------------------------
715
1289
 
716
- @contextmanager
717
- def start(
1290
+ @asynccontextmanager
1291
+ async def start(
718
1292
  self,
719
1293
  task_run_id: Optional[UUID] = None,
720
1294
  dependencies: Optional[Dict[str, Set[TaskRunInput]]] = None,
721
- ) -> Generator[None, None, None]:
722
- with self.initialize_run(task_run_id=task_run_id, dependencies=dependencies):
723
- self.begin_run()
1295
+ ) -> AsyncGenerator[None, None]:
1296
+ async with self.initialize_run(
1297
+ task_run_id=task_run_id, dependencies=dependencies
1298
+ ):
1299
+ await self.begin_run()
724
1300
  try:
725
1301
  yield
726
1302
  finally:
727
- self.call_hooks()
1303
+ await self.call_hooks()
728
1304
 
729
- @contextmanager
730
- def transaction_context(self) -> Generator[Transaction, None, None]:
1305
+ @asynccontextmanager
1306
+ async def transaction_context(self) -> AsyncGenerator[Transaction, None]:
731
1307
  result_factory = getattr(TaskRunContext.get(), "result_factory", None)
732
1308
 
733
1309
  # refresh cache setting is now repurposes as overwrite transaction record
@@ -744,13 +1320,12 @@ class TaskRunEngine(Generic[P, R]):
744
1320
  ) as txn:
745
1321
  yield txn
746
1322
 
747
- @contextmanager
748
- def run_context(self):
749
- timeout_context = timeout_async if self.task.isasync else timeout
1323
+ @asynccontextmanager
1324
+ async def run_context(self):
750
1325
  # reenter the run context to ensure it is up to date for every run
751
- with self.setup_run_context():
1326
+ async with self.setup_run_context():
752
1327
  try:
753
- with timeout_context(
1328
+ with timeout_async(
754
1329
  seconds=self.task.timeout_seconds,
755
1330
  timeout_exc_type=TaskRunTimeoutError,
756
1331
  ):
@@ -762,11 +1337,11 @@ class TaskRunEngine(Generic[P, R]):
762
1337
 
763
1338
  yield self
764
1339
  except TimeoutError as exc:
765
- self.handle_timeout(exc)
1340
+ await self.handle_timeout(exc)
766
1341
  except Exception as exc:
767
- self.handle_exception(exc)
1342
+ await self.handle_exception(exc)
768
1343
 
769
- def call_task_fn(
1344
+ async def call_task_fn(
770
1345
  self, transaction: Transaction
771
1346
  ) -> Union[R, Coroutine[Any, Any, R]]:
772
1347
  """
@@ -774,24 +1349,23 @@ class TaskRunEngine(Generic[P, R]):
774
1349
  task is async.
775
1350
  """
776
1351
  parameters = self.parameters or {}
777
- if self.task.isasync:
778
-
779
- async def _call_task_fn():
780
- if transaction.is_committed():
781
- result = transaction.read()
782
- else:
783
- result = await call_with_parameters(self.task.fn, parameters)
784
- self.handle_success(result, transaction=transaction)
785
- return result
786
-
787
- return _call_task_fn()
1352
+ if transaction.is_committed():
1353
+ result = transaction.read()
788
1354
  else:
789
- if transaction.is_committed():
790
- result = transaction.read()
1355
+ if (
1356
+ PREFECT_EXPERIMENTAL_ENABLE_CLIENT_SIDE_TASK_ORCHESTRATION.value()
1357
+ and self.task.tags
1358
+ ):
1359
+ # Acquire a concurrency slot for each tag, but only if a limit
1360
+ # matching the tag already exists.
1361
+ async with aconcurrency(
1362
+ list(self.task.tags), occupy=1, create_if_missing=False
1363
+ ):
1364
+ result = await call_with_parameters(self.task.fn, parameters)
791
1365
  else:
792
- result = call_with_parameters(self.task.fn, parameters)
793
- self.handle_success(result, transaction=transaction)
794
- return result
1366
+ result = await call_with_parameters(self.task.fn, parameters)
1367
+ await self.handle_success(result, transaction=transaction)
1368
+ return result
795
1369
 
796
1370
 
797
1371
  def run_task_sync(
@@ -804,7 +1378,7 @@ def run_task_sync(
804
1378
  dependencies: Optional[Dict[str, Set[TaskRunInput]]] = None,
805
1379
  context: Optional[Dict[str, Any]] = None,
806
1380
  ) -> Union[R, State, None]:
807
- engine = TaskRunEngine[P, R](
1381
+ engine = SyncTaskRunEngine[P, R](
808
1382
  task=task,
809
1383
  parameters=parameters,
810
1384
  task_run=task_run,
@@ -831,7 +1405,7 @@ async def run_task_async(
831
1405
  dependencies: Optional[Dict[str, Set[TaskRunInput]]] = None,
832
1406
  context: Optional[Dict[str, Any]] = None,
833
1407
  ) -> Union[R, State, None]:
834
- engine = TaskRunEngine[P, R](
1408
+ engine = AsyncTaskRunEngine[P, R](
835
1409
  task=task,
836
1410
  parameters=parameters,
837
1411
  task_run=task_run,
@@ -839,13 +1413,13 @@ async def run_task_async(
839
1413
  context=context,
840
1414
  )
841
1415
 
842
- with engine.start(task_run_id=task_run_id, dependencies=dependencies):
1416
+ async with engine.start(task_run_id=task_run_id, dependencies=dependencies):
843
1417
  while engine.is_running():
844
1418
  await engine.wait_until_ready()
845
- with engine.run_context(), engine.transaction_context() as txn:
1419
+ async with engine.run_context(), engine.transaction_context() as txn:
846
1420
  await engine.call_task_fn(txn)
847
1421
 
848
- return engine.state if return_type == "state" else engine.result()
1422
+ return engine.state if return_type == "state" else await engine.result()
849
1423
 
850
1424
 
851
1425
  def run_generator_task_sync(
@@ -861,7 +1435,7 @@ def run_generator_task_sync(
861
1435
  if return_type != "result":
862
1436
  raise ValueError("The return_type for a generator task must be 'result'")
863
1437
 
864
- engine = TaskRunEngine[P, R](
1438
+ engine = SyncTaskRunEngine[P, R](
865
1439
  task=task,
866
1440
  parameters=parameters,
867
1441
  task_run=task_run,
@@ -915,7 +1489,7 @@ async def run_generator_task_async(
915
1489
  ) -> AsyncGenerator[R, None]:
916
1490
  if return_type != "result":
917
1491
  raise ValueError("The return_type for a generator task must be 'result'")
918
- engine = TaskRunEngine[P, R](
1492
+ engine = AsyncTaskRunEngine[P, R](
919
1493
  task=task,
920
1494
  parameters=parameters,
921
1495
  task_run=task_run,
@@ -923,10 +1497,10 @@ async def run_generator_task_async(
923
1497
  context=context,
924
1498
  )
925
1499
 
926
- with engine.start(task_run_id=task_run_id, dependencies=dependencies):
1500
+ async with engine.start(task_run_id=task_run_id, dependencies=dependencies):
927
1501
  while engine.is_running():
928
1502
  await engine.wait_until_ready()
929
- with engine.run_context(), engine.transaction_context() as txn:
1503
+ async with engine.run_context(), engine.transaction_context() as txn:
930
1504
  # TODO: generators should default to commit_mode=OFF
931
1505
  # because they are dynamic by definition
932
1506
  # for now we just prevent this branch explicitly
@@ -950,13 +1524,13 @@ async def run_generator_task_async(
950
1524
  link_state_to_result(engine.state, gen_result)
951
1525
  yield gen_result
952
1526
  except (StopAsyncIteration, GeneratorExit) as exc:
953
- engine.handle_success(None, transaction=txn)
1527
+ await engine.handle_success(None, transaction=txn)
954
1528
  if isinstance(exc, GeneratorExit):
955
1529
  gen.throw(exc)
956
1530
 
957
1531
  # async generators can't return, but we can raise failures here
958
1532
  if engine.state.is_failed():
959
- engine.result()
1533
+ await engine.result()
960
1534
 
961
1535
 
962
1536
  def run_task(