prefect-client: prefect_client-3.0.0rc13-py3-none-any.whl → prefect_client-3.0.0rc15-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (36)
  1. prefect/_internal/compatibility/deprecated.py +0 -53
  2. prefect/blocks/core.py +132 -4
  3. prefect/blocks/notifications.py +26 -3
  4. prefect/client/base.py +30 -24
  5. prefect/client/orchestration.py +121 -47
  6. prefect/client/utilities.py +4 -4
  7. prefect/concurrency/asyncio.py +48 -7
  8. prefect/concurrency/context.py +24 -0
  9. prefect/concurrency/services.py +24 -8
  10. prefect/concurrency/sync.py +30 -3
  11. prefect/context.py +85 -24
  12. prefect/events/clients.py +93 -60
  13. prefect/events/utilities.py +0 -2
  14. prefect/events/worker.py +9 -2
  15. prefect/flow_engine.py +6 -3
  16. prefect/flows.py +176 -12
  17. prefect/futures.py +84 -7
  18. prefect/profiles.toml +16 -2
  19. prefect/runner/runner.py +6 -1
  20. prefect/runner/storage.py +4 -0
  21. prefect/settings.py +108 -14
  22. prefect/task_engine.py +901 -285
  23. prefect/task_runs.py +24 -1
  24. prefect/task_worker.py +7 -1
  25. prefect/tasks.py +9 -5
  26. prefect/utilities/asyncutils.py +0 -6
  27. prefect/utilities/callables.py +5 -3
  28. prefect/utilities/engine.py +3 -0
  29. prefect/utilities/importtools.py +138 -58
  30. prefect/utilities/schema_tools/validation.py +30 -0
  31. prefect/utilities/services.py +32 -0
  32. {prefect_client-3.0.0rc13.dist-info → prefect_client-3.0.0rc15.dist-info}/METADATA +39 -39
  33. {prefect_client-3.0.0rc13.dist-info → prefect_client-3.0.0rc15.dist-info}/RECORD +36 -35
  34. {prefect_client-3.0.0rc13.dist-info → prefect_client-3.0.0rc15.dist-info}/WHEEL +1 -1
  35. {prefect_client-3.0.0rc13.dist-info → prefect_client-3.0.0rc15.dist-info}/LICENSE +0 -0
  36. {prefect_client-3.0.0rc13.dist-info → prefect_client-3.0.0rc15.dist-info}/top_level.txt +0 -0
prefect/task_engine.py CHANGED
@@ -3,7 +3,7 @@ import logging
3
3
  import threading
4
4
  import time
5
5
  from asyncio import CancelledError
6
- from contextlib import ExitStack, contextmanager
6
+ from contextlib import ExitStack, asynccontextmanager, contextmanager
7
7
  from dataclasses import dataclass, field
8
8
  from functools import wraps
9
9
  from textwrap import dedent
@@ -31,12 +31,16 @@ import pendulum
31
31
  from typing_extensions import ParamSpec
32
32
 
33
33
  from prefect import Task
34
- from prefect.client.orchestration import SyncPrefectClient
34
+ from prefect.client.orchestration import PrefectClient, SyncPrefectClient, get_client
35
35
  from prefect.client.schemas import TaskRun
36
36
  from prefect.client.schemas.objects import State, TaskRunInput
37
+ from prefect.concurrency.asyncio import concurrency as aconcurrency
38
+ from prefect.concurrency.context import ConcurrencyContext
39
+ from prefect.concurrency.sync import concurrency
37
40
  from prefect.context import (
38
- ClientContext,
41
+ AsyncClientContext,
39
42
  FlowRunContext,
43
+ SyncClientContext,
40
44
  TaskRunContext,
41
45
  hydrated_context,
42
46
  )
@@ -59,6 +63,7 @@ from prefect.settings import (
59
63
  )
60
64
  from prefect.states import (
61
65
  AwaitingRetry,
66
+ Completed,
62
67
  Failed,
63
68
  Paused,
64
69
  Pending,
@@ -77,6 +82,7 @@ from prefect.utilities.engine import (
77
82
  _get_hook_name,
78
83
  emit_task_run_state_change_event,
79
84
  link_state_to_result,
85
+ propose_state,
80
86
  propose_state_sync,
81
87
  resolve_to_final_result,
82
88
  )
@@ -86,47 +92,767 @@ from prefect.utilities.timeout import timeout, timeout_async
86
92
  P = ParamSpec("P")
87
93
  R = TypeVar("R")
88
94
 
95
+ BACKOFF_MAX = 10
96
+
89
97
 
90
98
  class TaskRunTimeoutError(TimeoutError):
91
99
  """Raised when a task run exceeds its timeout."""
92
100
 
93
101
 
94
- @dataclass
95
- class TaskRunEngine(Generic[P, R]):
96
- task: Union[Task[P, R], Task[P, Coroutine[Any, Any, R]]]
97
- logger: logging.Logger = field(default_factory=lambda: get_logger("engine"))
98
- parameters: Optional[Dict[str, Any]] = None
99
- task_run: Optional[TaskRun] = None
100
- retries: int = 0
101
- wait_for: Optional[Iterable[PrefectFuture]] = None
102
- context: Optional[Dict[str, Any]] = None
103
- # holds the return value from the user code
104
- _return_value: Union[R, Type[NotSet]] = NotSet
105
- # holds the exception raised by the user code, if any
106
- _raised: Union[Exception, Type[NotSet]] = NotSet
107
- _initial_run_context: Optional[TaskRunContext] = None
108
- _is_started: bool = False
109
- _client: Optional[SyncPrefectClient] = None
110
- _task_name_set: bool = False
111
- _last_event: Optional[PrefectEvent] = None
102
+ @dataclass
103
+ class BaseTaskRunEngine(Generic[P, R]):
104
+ task: Union[Task[P, R], Task[P, Coroutine[Any, Any, R]]]
105
+ logger: logging.Logger = field(default_factory=lambda: get_logger("engine"))
106
+ parameters: Optional[Dict[str, Any]] = None
107
+ task_run: Optional[TaskRun] = None
108
+ retries: int = 0
109
+ wait_for: Optional[Iterable[PrefectFuture]] = None
110
+ context: Optional[Dict[str, Any]] = None
111
+ # holds the return value from the user code
112
+ _return_value: Union[R, Type[NotSet]] = NotSet
113
+ # holds the exception raised by the user code, if any
114
+ _raised: Union[Exception, Type[NotSet]] = NotSet
115
+ _initial_run_context: Optional[TaskRunContext] = None
116
+ _is_started: bool = False
117
+ _task_name_set: bool = False
118
+ _last_event: Optional[PrefectEvent] = None
119
+
120
+ def __post_init__(self):
121
+ if self.parameters is None:
122
+ self.parameters = {}
123
+
124
+ @property
125
+ def state(self) -> State:
126
+ if not self.task_run:
127
+ raise ValueError("Task run is not set")
128
+ return self.task_run.state
129
+
130
+ def is_cancelled(self) -> bool:
131
+ if (
132
+ self.context
133
+ and "cancel_event" in self.context
134
+ and isinstance(self.context["cancel_event"], threading.Event)
135
+ ):
136
+ return self.context["cancel_event"].is_set()
137
+ return False
138
+
139
+ def compute_transaction_key(self) -> Optional[str]:
140
+ key = None
141
+ if self.task.cache_policy:
142
+ flow_run_context = FlowRunContext.get()
143
+ task_run_context = TaskRunContext.get()
144
+
145
+ if flow_run_context:
146
+ parameters = flow_run_context.parameters
147
+ else:
148
+ parameters = None
149
+
150
+ key = self.task.cache_policy.compute_key(
151
+ task_ctx=task_run_context,
152
+ inputs=self.parameters,
153
+ flow_parameters=parameters,
154
+ )
155
+ elif self.task.result_storage_key is not None:
156
+ key = _format_user_supplied_storage_key(self.task.result_storage_key)
157
+ return key
158
+
159
+ def _resolve_parameters(self):
160
+ if not self.parameters:
161
+ return {}
162
+
163
+ resolved_parameters = {}
164
+ for parameter, value in self.parameters.items():
165
+ try:
166
+ resolved_parameters[parameter] = visit_collection(
167
+ value,
168
+ visit_fn=resolve_to_final_result,
169
+ return_data=True,
170
+ max_depth=-1,
171
+ remove_annotations=True,
172
+ context={},
173
+ )
174
+ except UpstreamTaskError:
175
+ raise
176
+ except Exception as exc:
177
+ raise PrefectException(
178
+ f"Failed to resolve inputs in parameter {parameter!r}. If your"
179
+ " parameter type is not supported, consider using the `quote`"
180
+ " annotation to skip resolution of inputs."
181
+ ) from exc
182
+
183
+ self.parameters = resolved_parameters
184
+
185
+ def _wait_for_dependencies(self):
186
+ if not self.wait_for:
187
+ return
188
+
189
+ visit_collection(
190
+ self.wait_for,
191
+ visit_fn=resolve_to_final_result,
192
+ return_data=False,
193
+ max_depth=-1,
194
+ remove_annotations=True,
195
+ context={"current_task_run": self.task_run, "current_task": self.task},
196
+ )
197
+
198
+ def record_terminal_state_timing(self, state: State) -> None:
199
+ if PREFECT_EXPERIMENTAL_ENABLE_CLIENT_SIDE_TASK_ORCHESTRATION:
200
+ if self.task_run.start_time and not self.task_run.end_time:
201
+ self.task_run.end_time = state.timestamp
202
+
203
+ if self.task_run.state.is_running():
204
+ self.task_run.total_run_time += (
205
+ state.timestamp - self.task_run.state.timestamp
206
+ )
207
+
208
+ def is_running(self) -> bool:
209
+ """Whether or not the engine is currently running a task."""
210
+ if (task_run := getattr(self, "task_run", None)) is None:
211
+ return False
212
+ return task_run.state.is_running() or task_run.state.is_scheduled()
213
+
214
+ def log_finished_message(self):
215
+ # If debugging, use the more complete `repr` than the usual `str` description
216
+ display_state = repr(self.state) if PREFECT_DEBUG_MODE else str(self.state)
217
+ level = logging.INFO if self.state.is_completed() else logging.ERROR
218
+ msg = f"Finished in state {display_state}"
219
+ if self.state.is_pending():
220
+ msg += (
221
+ "\nPlease wait for all submitted tasks to complete"
222
+ " before exiting your flow by calling `.wait()` on the "
223
+ "`PrefectFuture` returned from your `.submit()` calls."
224
+ )
225
+ msg += dedent(
226
+ """
227
+
228
+ Example:
229
+
230
+ from prefect import flow, task
231
+
232
+ @task
233
+ def say_hello(name):
234
+ print f"Hello, {name}!"
235
+
236
+ @flow
237
+ def example_flow():
238
+ future = say_hello.submit(name="Marvin)
239
+ future.wait()
240
+
241
+ example_flow()
242
+ """
243
+ )
244
+ self.logger.log(
245
+ level=level,
246
+ msg=msg,
247
+ )
248
+
249
+ def handle_rollback(self, txn: Transaction) -> None:
250
+ assert self.task_run is not None
251
+
252
+ rolled_back_state = Completed(
253
+ name="RolledBack",
254
+ message="Task rolled back as part of transaction",
255
+ )
256
+
257
+ self._last_event = emit_task_run_state_change_event(
258
+ task_run=self.task_run,
259
+ initial_state=self.state,
260
+ validated_state=rolled_back_state,
261
+ follows=self._last_event,
262
+ )
263
+
264
+
265
+ @dataclass
266
+ class SyncTaskRunEngine(BaseTaskRunEngine[P, R]):
267
+ _client: Optional[SyncPrefectClient] = None
268
+
269
+ @property
270
+ def client(self) -> SyncPrefectClient:
271
+ if not self._is_started or self._client is None:
272
+ raise RuntimeError("Engine has not started.")
273
+ return self._client
274
+
275
+ def can_retry(self, exc: Exception) -> bool:
276
+ retry_condition: Optional[
277
+ Callable[[Task[P, Coroutine[Any, Any, R]], TaskRun, State], bool]
278
+ ] = self.task.retry_condition_fn
279
+ if not self.task_run:
280
+ raise ValueError("Task run is not set")
281
+ try:
282
+ self.logger.debug(
283
+ f"Running `retry_condition_fn` check {retry_condition!r} for task"
284
+ f" {self.task.name!r}"
285
+ )
286
+ state = Failed(
287
+ data=exc,
288
+ message=f"Task run encountered unexpected exception: {repr(exc)}",
289
+ )
290
+ if inspect.iscoroutinefunction(retry_condition):
291
+ should_retry = run_coro_as_sync(
292
+ retry_condition(self.task, self.task_run, state)
293
+ )
294
+ elif inspect.isfunction(retry_condition):
295
+ should_retry = retry_condition(self.task, self.task_run, state)
296
+ else:
297
+ should_retry = not retry_condition
298
+ return should_retry
299
+ except Exception:
300
+ self.logger.error(
301
+ (
302
+ "An error was encountered while running `retry_condition_fn` check"
303
+ f" '{retry_condition!r}' for task {self.task.name!r}"
304
+ ),
305
+ exc_info=True,
306
+ )
307
+ return False
308
+
309
+ def call_hooks(self, state: Optional[State] = None):
310
+ if state is None:
311
+ state = self.state
312
+ task = self.task
313
+ task_run = self.task_run
314
+
315
+ if not task_run:
316
+ raise ValueError("Task run is not set")
317
+
318
+ if state.is_failed() and task.on_failure_hooks:
319
+ hooks = task.on_failure_hooks
320
+ elif state.is_completed() and task.on_completion_hooks:
321
+ hooks = task.on_completion_hooks
322
+ else:
323
+ hooks = None
324
+
325
+ for hook in hooks or []:
326
+ hook_name = _get_hook_name(hook)
327
+
328
+ try:
329
+ self.logger.info(
330
+ f"Running hook {hook_name!r} in response to entering state"
331
+ f" {state.name!r}"
332
+ )
333
+ result = hook(task, task_run, state)
334
+ if inspect.isawaitable(result):
335
+ run_coro_as_sync(result)
336
+ except Exception:
337
+ self.logger.error(
338
+ f"An error was encountered while running hook {hook_name!r}",
339
+ exc_info=True,
340
+ )
341
+ else:
342
+ self.logger.info(f"Hook {hook_name!r} finished running successfully")
343
+
344
+ def begin_run(self):
345
+ try:
346
+ self._resolve_parameters()
347
+ self._wait_for_dependencies()
348
+ except UpstreamTaskError as upstream_exc:
349
+ state = self.set_state(
350
+ Pending(
351
+ name="NotReady",
352
+ message=str(upstream_exc),
353
+ ),
354
+ # if orchestrating a run already in a pending state, force orchestration to
355
+ # update the state name
356
+ force=self.state.is_pending(),
357
+ )
358
+ return
359
+
360
+ new_state = Running()
361
+
362
+ if PREFECT_EXPERIMENTAL_ENABLE_CLIENT_SIDE_TASK_ORCHESTRATION:
363
+ self.task_run.start_time = new_state.timestamp
364
+ self.task_run.run_count += 1
365
+
366
+ flow_run_context = FlowRunContext.get()
367
+ if flow_run_context:
368
+ # Carry forward any task run information from the flow run
369
+ flow_run = flow_run_context.flow_run
370
+ self.task_run.flow_run_run_count = flow_run.run_count
371
+
372
+ state = self.set_state(new_state)
373
+
374
+ # TODO: this is temporary until the API stops rejecting state transitions
375
+ # and the client / transaction store becomes the source of truth
376
+ # this is a bandaid caused by the API storing a Completed state with a bad
377
+ # result reference that no longer exists
378
+ if state.is_completed():
379
+ try:
380
+ state.result(retry_result_failure=False, _sync=True)
381
+ except Exception:
382
+ state = self.set_state(new_state, force=True)
383
+
384
+ backoff_count = 0
385
+
386
+ # TODO: Could this listen for state change events instead of polling?
387
+ while state.is_pending() or state.is_paused():
388
+ if backoff_count < BACKOFF_MAX:
389
+ backoff_count += 1
390
+ interval = clamped_poisson_interval(
391
+ average_interval=backoff_count, clamping_factor=0.3
392
+ )
393
+ time.sleep(interval)
394
+ state = self.set_state(new_state)
395
+
396
+ def set_state(self, state: State, force: bool = False) -> State:
397
+ last_state = self.state
398
+ if not self.task_run:
399
+ raise ValueError("Task run is not set")
400
+
401
+ if PREFECT_EXPERIMENTAL_ENABLE_CLIENT_SIDE_TASK_ORCHESTRATION:
402
+ self.task_run.state = new_state = state
403
+
404
+ # Ensure that the state_details are populated with the current run IDs
405
+ new_state.state_details.task_run_id = self.task_run.id
406
+ new_state.state_details.flow_run_id = self.task_run.flow_run_id
407
+
408
+ # Predictively update the de-normalized task_run.state_* attributes
409
+ self.task_run.state_id = new_state.id
410
+ self.task_run.state_type = new_state.type
411
+ self.task_run.state_name = new_state.name
412
+
413
+ if new_state.is_final():
414
+ if (
415
+ isinstance(state.data, BaseResult)
416
+ and state.data.has_cached_object()
417
+ ):
418
+ # Avoid fetching the result unless it is cached, otherwise we defeat
419
+ # the purpose of disabling `cache_result_in_memory`
420
+ result = state.result(raise_on_failure=False, fetch=True)
421
+ if inspect.isawaitable(result):
422
+ result = run_coro_as_sync(result)
423
+ else:
424
+ result = state.data
425
+
426
+ link_state_to_result(state, result)
427
+
428
+ else:
429
+ try:
430
+ new_state = propose_state_sync(
431
+ self.client, state, task_run_id=self.task_run.id, force=force
432
+ )
433
+ except Pause as exc:
434
+ # We shouldn't get a pause signal without a state, but if this happens,
435
+ # just use a Paused state to assume an in-process pause.
436
+ new_state = exc.state if exc.state else Paused()
437
+ if new_state.state_details.pause_reschedule:
438
+ # If we're being asked to pause and reschedule, we should exit the
439
+ # task and expect to be resumed later.
440
+ raise
441
+
442
+ # currently this is a hack to keep a reference to the state object
443
+ # that has an in-memory result attached to it; using the API state
444
+ # could result in losing that reference
445
+ self.task_run.state = new_state
446
+
447
+ # emit a state change event
448
+ self._last_event = emit_task_run_state_change_event(
449
+ task_run=self.task_run,
450
+ initial_state=last_state,
451
+ validated_state=self.task_run.state,
452
+ follows=self._last_event,
453
+ )
454
+
455
+ return new_state
456
+
457
+ def result(self, raise_on_failure: bool = True) -> "Union[R, State, None]":
458
+ if self._return_value is not NotSet:
459
+ # if the return value is a BaseResult, we need to fetch it
460
+ if isinstance(self._return_value, BaseResult):
461
+ _result = self._return_value.get()
462
+ if inspect.isawaitable(_result):
463
+ _result = run_coro_as_sync(_result)
464
+ return _result
465
+
466
+ # otherwise, return the value as is
467
+ return self._return_value
468
+
469
+ if self._raised is not NotSet:
470
+ # if the task raised an exception, raise it
471
+ if raise_on_failure:
472
+ raise self._raised
473
+
474
+ # otherwise, return the exception
475
+ return self._raised
476
+
477
+ def handle_success(self, result: R, transaction: Transaction) -> R:
478
+ result_factory = getattr(TaskRunContext.get(), "result_factory", None)
479
+ if result_factory is None:
480
+ raise ValueError("Result factory is not set")
481
+
482
+ if self.task.cache_expiration is not None:
483
+ expiration = pendulum.now("utc") + self.task.cache_expiration
484
+ else:
485
+ expiration = None
486
+
487
+ terminal_state = run_coro_as_sync(
488
+ return_value_to_state(
489
+ result,
490
+ result_factory=result_factory,
491
+ key=transaction.key,
492
+ expiration=expiration,
493
+ # defer persistence to transaction commit
494
+ defer_persistence=True,
495
+ )
496
+ )
497
+ transaction.stage(
498
+ terminal_state.data,
499
+ on_rollback_hooks=[self.handle_rollback]
500
+ + [
501
+ _with_transaction_hook_logging(hook, "rollback", self.logger)
502
+ for hook in self.task.on_rollback_hooks
503
+ ],
504
+ on_commit_hooks=[
505
+ _with_transaction_hook_logging(hook, "commit", self.logger)
506
+ for hook in self.task.on_commit_hooks
507
+ ],
508
+ )
509
+ if transaction.is_committed():
510
+ terminal_state.name = "Cached"
511
+
512
+ self.record_terminal_state_timing(terminal_state)
513
+ self.set_state(terminal_state)
514
+ self._return_value = result
515
+ return result
516
+
517
+ def handle_retry(self, exc: Exception) -> bool:
518
+ """Handle any task run retries.
519
+
520
+ - If the task has retries left, and the retry condition is met, set the task to retrying and return True.
521
+ - If the task has a retry delay, place in AwaitingRetry state with a delayed scheduled time.
522
+ - If the task has no retries left, or the retry condition is not met, return False.
523
+ """
524
+ if self.retries < self.task.retries and self.can_retry(exc):
525
+ if self.task.retry_delay_seconds:
526
+ delay = (
527
+ self.task.retry_delay_seconds[
528
+ min(self.retries, len(self.task.retry_delay_seconds) - 1)
529
+ ] # repeat final delay value if attempts exceed specified delays
530
+ if isinstance(self.task.retry_delay_seconds, Sequence)
531
+ else self.task.retry_delay_seconds
532
+ )
533
+ new_state = AwaitingRetry(
534
+ scheduled_time=pendulum.now("utc").add(seconds=delay)
535
+ )
536
+ else:
537
+ delay = None
538
+ new_state = Retrying()
539
+ if PREFECT_EXPERIMENTAL_ENABLE_CLIENT_SIDE_TASK_ORCHESTRATION:
540
+ self.task_run.run_count += 1
541
+
542
+ self.logger.info(
543
+ "Task run failed with exception: %r - " "Retry %s/%s will start %s",
544
+ exc,
545
+ self.retries + 1,
546
+ self.task.retries,
547
+ str(delay) + " second(s) from now" if delay else "immediately",
548
+ )
549
+
550
+ self.set_state(new_state, force=True)
551
+ self.retries = self.retries + 1
552
+ return True
553
+ elif self.retries >= self.task.retries:
554
+ self.logger.error(
555
+ "Task run failed with exception: %r - Retries are exhausted",
556
+ exc,
557
+ exc_info=True,
558
+ )
559
+ return False
560
+
561
+ return False
562
+
563
+ def handle_exception(self, exc: Exception) -> None:
564
+ # If the task fails, and we have retries left, set the task to retrying.
565
+ if not self.handle_retry(exc):
566
+ # If the task has no retries left, or the retry condition is not met, set the task to failed.
567
+ context = TaskRunContext.get()
568
+ state = run_coro_as_sync(
569
+ exception_to_failed_state(
570
+ exc,
571
+ message="Task run encountered an exception",
572
+ result_factory=getattr(context, "result_factory", None),
573
+ )
574
+ )
575
+ self.record_terminal_state_timing(state)
576
+ self.set_state(state)
577
+ self._raised = exc
578
+
579
+ def handle_timeout(self, exc: TimeoutError) -> None:
580
+ if not self.handle_retry(exc):
581
+ if isinstance(exc, TaskRunTimeoutError):
582
+ message = f"Task run exceeded timeout of {self.task.timeout_seconds} second(s)"
583
+ else:
584
+ message = f"Task run failed due to timeout: {exc!r}"
585
+ self.logger.error(message)
586
+ state = Failed(
587
+ data=exc,
588
+ message=message,
589
+ name="TimedOut",
590
+ )
591
+ self.set_state(state)
592
+ self._raised = exc
593
+
594
+ def handle_crash(self, exc: BaseException) -> None:
595
+ state = run_coro_as_sync(exception_to_crashed_state(exc))
596
+ self.logger.error(f"Crash detected! {state.message}")
597
+ self.logger.debug("Crash details:", exc_info=exc)
598
+ self.record_terminal_state_timing(state)
599
+ self.set_state(state, force=True)
600
+ self._raised = exc
601
+
602
+ @contextmanager
603
+ def setup_run_context(self, client: Optional[SyncPrefectClient] = None):
604
+ from prefect.utilities.engine import (
605
+ _resolve_custom_task_run_name,
606
+ should_log_prints,
607
+ )
608
+
609
+ if client is None:
610
+ client = self.client
611
+ if not self.task_run:
612
+ raise ValueError("Task run is not set")
613
+
614
+ if not PREFECT_EXPERIMENTAL_ENABLE_CLIENT_SIDE_TASK_ORCHESTRATION:
615
+ self.task_run = client.read_task_run(self.task_run.id)
616
+ with ExitStack() as stack:
617
+ if log_prints := should_log_prints(self.task):
618
+ stack.enter_context(patch_print())
619
+ stack.enter_context(
620
+ TaskRunContext(
621
+ task=self.task,
622
+ log_prints=log_prints,
623
+ task_run=self.task_run,
624
+ parameters=self.parameters,
625
+ result_factory=run_coro_as_sync(ResultFactory.from_task(self.task)), # type: ignore
626
+ client=client,
627
+ )
628
+ )
629
+ stack.enter_context(ConcurrencyContext())
630
+
631
+ self.logger = task_run_logger(task_run=self.task_run, task=self.task) # type: ignore
632
+
633
+ if not PREFECT_EXPERIMENTAL_ENABLE_CLIENT_SIDE_TASK_ORCHESTRATION:
634
+ # update the task run name if necessary
635
+ if not self._task_name_set and self.task.task_run_name:
636
+ task_run_name = _resolve_custom_task_run_name(
637
+ task=self.task, parameters=self.parameters
638
+ )
639
+ self.client.set_task_run_name(
640
+ task_run_id=self.task_run.id, name=task_run_name
641
+ )
642
+ self.logger.extra["task_run_name"] = task_run_name
643
+ self.logger.debug(
644
+ f"Renamed task run {self.task_run.name!r} to {task_run_name!r}"
645
+ )
646
+ self.task_run.name = task_run_name
647
+ self._task_name_set = True
648
+ yield
649
+
650
+ @contextmanager
651
+ def initialize_run(
652
+ self,
653
+ task_run_id: Optional[UUID] = None,
654
+ dependencies: Optional[Dict[str, Set[TaskRunInput]]] = None,
655
+ ) -> Generator["SyncTaskRunEngine", Any, Any]:
656
+ """
657
+ Enters a client context and creates a task run if needed.
658
+ """
659
+
660
+ with hydrated_context(self.context):
661
+ with SyncClientContext.get_or_create() as client_ctx:
662
+ self._client = client_ctx.client
663
+ self._is_started = True
664
+ try:
665
+ if PREFECT_EXPERIMENTAL_ENABLE_CLIENT_SIDE_TASK_ORCHESTRATION:
666
+ from prefect.utilities.engine import (
667
+ _resolve_custom_task_run_name,
668
+ )
669
+
670
+ task_run_name = (
671
+ _resolve_custom_task_run_name(
672
+ task=self.task, parameters=self.parameters
673
+ )
674
+ if self.task.task_run_name
675
+ else None
676
+ )
677
+
678
+ if self.task_run and task_run_name:
679
+ self.task_run.name = task_run_name
680
+
681
+ if not self.task_run:
682
+ self.task_run = run_coro_as_sync(
683
+ self.task.create_local_run(
684
+ id=task_run_id,
685
+ parameters=self.parameters,
686
+ flow_run_context=FlowRunContext.get(),
687
+ parent_task_run_context=TaskRunContext.get(),
688
+ wait_for=self.wait_for,
689
+ extra_task_inputs=dependencies,
690
+ task_run_name=task_run_name,
691
+ )
692
+ )
693
+ # Emit an event to capture that the task run was in the `PENDING` state.
694
+ self._last_event = emit_task_run_state_change_event(
695
+ task_run=self.task_run,
696
+ initial_state=None,
697
+ validated_state=self.task_run.state,
698
+ )
699
+ else:
700
+ if not self.task_run:
701
+ self.task_run = run_coro_as_sync(
702
+ self.task.create_run(
703
+ id=task_run_id,
704
+ parameters=self.parameters,
705
+ flow_run_context=FlowRunContext.get(),
706
+ parent_task_run_context=TaskRunContext.get(),
707
+ wait_for=self.wait_for,
708
+ extra_task_inputs=dependencies,
709
+ )
710
+ )
711
+ # Emit an event to capture that the task run was in the `PENDING` state.
712
+ self._last_event = emit_task_run_state_change_event(
713
+ task_run=self.task_run,
714
+ initial_state=None,
715
+ validated_state=self.task_run.state,
716
+ )
717
+
718
+ with self.setup_run_context():
719
+ # setup_run_context might update the task run name, so log creation here
720
+ self.logger.info(
721
+ f"Created task run {self.task_run.name!r} for task {self.task.name!r}"
722
+ )
723
+ yield self
724
+
725
+ except TerminationSignal as exc:
726
+ # TerminationSignals are caught and handled as crashes
727
+ self.handle_crash(exc)
728
+ raise exc
729
+
730
+ except Exception:
731
+ # regular exceptions are caught and re-raised to the user
732
+ raise
733
+ except (Pause, Abort) as exc:
734
+ # Do not capture internal signals as crashes
735
+ if isinstance(exc, Abort):
736
+ self.logger.error("Task run was aborted: %s", exc)
737
+ raise
738
+ except GeneratorExit:
739
+ # Do not capture generator exits as crashes
740
+ raise
741
+ except BaseException as exc:
742
+ # BaseExceptions are caught and handled as crashes
743
+ self.handle_crash(exc)
744
+ raise
745
+ finally:
746
+ self.log_finished_message()
747
+ self._is_started = False
748
+ self._client = None
749
+
750
+ async def wait_until_ready(self):
751
+ """Waits until the scheduled time (if its the future), then enters Running."""
752
+ if scheduled_time := self.state.state_details.scheduled_time:
753
+ sleep_time = (scheduled_time - pendulum.now("utc")).total_seconds()
754
+ await anyio.sleep(sleep_time if sleep_time > 0 else 0)
755
+ self.set_state(
756
+ Retrying() if self.state.name == "AwaitingRetry" else Running(),
757
+ force=True,
758
+ )
759
+
760
+ # --------------------------
761
+ #
762
+ # The following methods compose the main task run loop
763
+ #
764
+ # --------------------------
765
+
766
+ @contextmanager
767
+ def start(
768
+ self,
769
+ task_run_id: Optional[UUID] = None,
770
+ dependencies: Optional[Dict[str, Set[TaskRunInput]]] = None,
771
+ ) -> Generator[None, None, None]:
772
+ with self.initialize_run(task_run_id=task_run_id, dependencies=dependencies):
773
+ self.begin_run()
774
+ try:
775
+ yield
776
+ finally:
777
+ self.call_hooks()
778
+
779
+ @contextmanager
780
+ def transaction_context(self) -> Generator[Transaction, None, None]:
781
+ result_factory = getattr(TaskRunContext.get(), "result_factory", None)
782
+
783
+ # refresh cache setting is now repurposes as overwrite transaction record
784
+ overwrite = (
785
+ self.task.refresh_cache
786
+ if self.task.refresh_cache is not None
787
+ else PREFECT_TASKS_REFRESH_CACHE.value()
788
+ )
789
+ with transaction(
790
+ key=self.compute_transaction_key(),
791
+ store=ResultFactoryStore(result_factory=result_factory),
792
+ overwrite=overwrite,
793
+ logger=self.logger,
794
+ ) as txn:
795
+ yield txn
796
+
797
+ @contextmanager
798
+ def run_context(self):
799
+ # reenter the run context to ensure it is up to date for every run
800
+ with self.setup_run_context():
801
+ try:
802
+ with timeout(
803
+ seconds=self.task.timeout_seconds,
804
+ timeout_exc_type=TaskRunTimeoutError,
805
+ ):
806
+ self.logger.debug(
807
+ f"Executing task {self.task.name!r} for task run {self.task_run.name!r}..."
808
+ )
809
+ if self.is_cancelled():
810
+ raise CancelledError("Task run cancelled by the task runner")
811
+
812
+ yield self
813
+ except TimeoutError as exc:
814
+ self.handle_timeout(exc)
815
+ except Exception as exc:
816
+ self.handle_exception(exc)
817
+
818
+ def call_task_fn(
819
+ self, transaction: Transaction
820
+ ) -> Union[R, Coroutine[Any, Any, R]]:
821
+ """
822
+ Convenience method to call the task function. Returns a coroutine if the
823
+ task is async.
824
+ """
825
+ parameters = self.parameters or {}
826
+ if transaction.is_committed():
827
+ result = transaction.read()
828
+ else:
829
+ if (
830
+ PREFECT_EXPERIMENTAL_ENABLE_CLIENT_SIDE_TASK_ORCHESTRATION.value()
831
+ and self.task.tags
832
+ ):
833
+ # Acquire a concurrency slot for each tag, but only if a limit
834
+ # matching the tag already exists.
835
+ with concurrency(
836
+ list(self.task.tags), occupy=1, create_if_missing=False
837
+ ):
838
+ result = call_with_parameters(self.task.fn, parameters)
839
+ else:
840
+ result = call_with_parameters(self.task.fn, parameters)
841
+ self.handle_success(result, transaction=transaction)
842
+ return result
843
+
112
844
 
113
- def __post_init__(self):
114
- if self.parameters is None:
115
- self.parameters = {}
845
+ @dataclass
846
+ class AsyncTaskRunEngine(BaseTaskRunEngine[P, R]):
847
+ _client: Optional[PrefectClient] = None
116
848
 
117
849
  @property
118
- def client(self) -> SyncPrefectClient:
850
+ def client(self) -> PrefectClient:
119
851
  if not self._is_started or self._client is None:
120
852
  raise RuntimeError("Engine has not started.")
121
853
  return self._client
122
854
 
123
- @property
124
- def state(self) -> State:
125
- if not self.task_run:
126
- raise ValueError("Task run is not set")
127
- return self.task_run.state
128
-
129
- def can_retry(self, exc: Exception) -> bool:
855
+ async def can_retry(self, exc: Exception) -> bool:
130
856
  retry_condition: Optional[
131
857
  Callable[[Task[P, Coroutine[Any, Any, R]], TaskRun, State], bool]
132
858
  ] = self.task.retry_condition_fn
@@ -142,14 +868,13 @@ class TaskRunEngine(Generic[P, R]):
142
868
  message=f"Task run encountered unexpected exception: {repr(exc)}",
143
869
  )
144
870
  if inspect.iscoroutinefunction(retry_condition):
145
- should_retry = run_coro_as_sync(
146
- retry_condition(self.task, self.task_run, state)
147
- )
871
+ should_retry = await retry_condition(self.task, self.task_run, state)
148
872
  elif inspect.isfunction(retry_condition):
149
873
  should_retry = retry_condition(self.task, self.task_run, state)
150
874
  else:
151
875
  should_retry = not retry_condition
152
876
  return should_retry
877
+
153
878
  except Exception:
154
879
  self.logger.error(
155
880
  (
@@ -160,16 +885,7 @@ class TaskRunEngine(Generic[P, R]):
160
885
  )
161
886
  return False
162
887
 
163
- def is_cancelled(self) -> bool:
164
- if (
165
- self.context
166
- and "cancel_event" in self.context
167
- and isinstance(self.context["cancel_event"], threading.Event)
168
- ):
169
- return self.context["cancel_event"].is_set()
170
- return False
171
-
172
- def call_hooks(self, state: Optional[State] = None):
888
+ async def call_hooks(self, state: Optional[State] = None):
173
889
  if state is None:
174
890
  state = self.state
175
891
  task = self.task
@@ -195,7 +911,7 @@ class TaskRunEngine(Generic[P, R]):
195
911
  )
196
912
  result = hook(task, task_run, state)
197
913
  if inspect.isawaitable(result):
198
- run_coro_as_sync(result)
914
+ await result
199
915
  except Exception:
200
916
  self.logger.error(
201
917
  f"An error was encountered while running hook {hook_name!r}",
@@ -204,71 +920,12 @@ class TaskRunEngine(Generic[P, R]):
204
920
  else:
205
921
  self.logger.info(f"Hook {hook_name!r} finished running successfully")
206
922
 
207
- def compute_transaction_key(self) -> Optional[str]:
208
- key = None
209
- if self.task.cache_policy:
210
- flow_run_context = FlowRunContext.get()
211
- task_run_context = TaskRunContext.get()
212
-
213
- if flow_run_context:
214
- parameters = flow_run_context.parameters
215
- else:
216
- parameters = None
217
-
218
- key = self.task.cache_policy.compute_key(
219
- task_ctx=task_run_context,
220
- inputs=self.parameters,
221
- flow_parameters=parameters,
222
- )
223
- elif self.task.result_storage_key is not None:
224
- key = _format_user_supplied_storage_key(self.task.result_storage_key)
225
- return key
226
-
227
- def _resolve_parameters(self):
228
- if not self.parameters:
229
- return {}
230
-
231
- resolved_parameters = {}
232
- for parameter, value in self.parameters.items():
233
- try:
234
- resolved_parameters[parameter] = visit_collection(
235
- value,
236
- visit_fn=resolve_to_final_result,
237
- return_data=True,
238
- max_depth=-1,
239
- remove_annotations=True,
240
- context={},
241
- )
242
- except UpstreamTaskError:
243
- raise
244
- except Exception as exc:
245
- raise PrefectException(
246
- f"Failed to resolve inputs in parameter {parameter!r}. If your"
247
- " parameter type is not supported, consider using the `quote`"
248
- " annotation to skip resolution of inputs."
249
- ) from exc
250
-
251
- self.parameters = resolved_parameters
252
-
253
- def _wait_for_dependencies(self):
254
- if not self.wait_for:
255
- return
256
-
257
- visit_collection(
258
- self.wait_for,
259
- visit_fn=resolve_to_final_result,
260
- return_data=False,
261
- max_depth=-1,
262
- remove_annotations=True,
263
- context={"current_task_run": self.task_run, "current_task": self.task},
264
- )
265
-
266
- def begin_run(self):
923
+ async def begin_run(self):
267
924
  try:
268
925
  self._resolve_parameters()
269
926
  self._wait_for_dependencies()
270
927
  except UpstreamTaskError as upstream_exc:
271
- state = self.set_state(
928
+ state = await self.set_state(
272
929
  Pending(
273
930
  name="NotReady",
274
931
  message=str(upstream_exc),
@@ -291,7 +948,7 @@ class TaskRunEngine(Generic[P, R]):
291
948
  flow_run = flow_run_context.flow_run
292
949
  self.task_run.flow_run_run_count = flow_run.run_count
293
950
 
294
- state = self.set_state(new_state)
951
+ state = await self.set_state(new_state)
295
952
 
296
953
  # TODO: this is temporary until the API stops rejecting state transitions
297
954
  # and the client / transaction store becomes the source of truth
@@ -299,11 +956,10 @@ class TaskRunEngine(Generic[P, R]):
299
956
  # result reference that no longer exists
300
957
  if state.is_completed():
301
958
  try:
302
- state.result(retry_result_failure=False, _sync=True)
959
+ await state.result(retry_result_failure=False)
303
960
  except Exception:
304
- state = self.set_state(new_state, force=True)
961
+ state = await self.set_state(new_state, force=True)
305
962
 
306
- BACKOFF_MAX = 10
307
963
  backoff_count = 0
308
964
 
309
965
  # TODO: Could this listen for state change events instead of polling?
@@ -313,10 +969,10 @@ class TaskRunEngine(Generic[P, R]):
313
969
  interval = clamped_poisson_interval(
314
970
  average_interval=backoff_count, clamping_factor=0.3
315
971
  )
316
- time.sleep(interval)
317
- state = self.set_state(new_state)
972
+ await anyio.sleep(interval)
973
+ state = await self.set_state(new_state)
318
974
 
319
- def set_state(self, state: State, force: bool = False) -> State:
975
+ async def set_state(self, state: State, force: bool = False) -> State:
320
976
  last_state = self.state
321
977
  if not self.task_run:
322
978
  raise ValueError("Task run is not set")
@@ -332,9 +988,23 @@ class TaskRunEngine(Generic[P, R]):
332
988
  self.task_run.state_id = new_state.id
333
989
  self.task_run.state_type = new_state.type
334
990
  self.task_run.state_name = new_state.name
991
+
992
+ if new_state.is_final():
993
+ if (
994
+ isinstance(new_state.data, BaseResult)
995
+ and new_state.data.has_cached_object()
996
+ ):
997
+ # Avoid fetching the result unless it is cached, otherwise we defeat
998
+ # the purpose of disabling `cache_result_in_memory`
999
+ result = await new_state.result(raise_on_failure=False, fetch=True)
1000
+ else:
1001
+ result = new_state.data
1002
+
1003
+ link_state_to_result(new_state, result)
1004
+
335
1005
  else:
336
1006
  try:
337
- new_state = propose_state_sync(
1007
+ new_state = await propose_state(
338
1008
  self.client, state, task_run_id=self.task_run.id, force=force
339
1009
  )
340
1010
  except Pause as exc:
@@ -361,14 +1031,11 @@ class TaskRunEngine(Generic[P, R]):
361
1031
 
362
1032
  return new_state
363
1033
 
364
- def result(self, raise_on_failure: bool = True) -> "Union[R, State, None]":
1034
+ async def result(self, raise_on_failure: bool = True) -> "Union[R, State, None]":
365
1035
  if self._return_value is not NotSet:
366
1036
  # if the return value is a BaseResult, we need to fetch it
367
1037
  if isinstance(self._return_value, BaseResult):
368
- _result = self._return_value.get()
369
- if inspect.isawaitable(_result):
370
- _result = run_coro_as_sync(_result)
371
- return _result
1038
+ return await self._return_value.get()
372
1039
 
373
1040
  # otherwise, return the value as is
374
1041
  return self._return_value
@@ -381,7 +1048,7 @@ class TaskRunEngine(Generic[P, R]):
381
1048
  # otherwise, return the exception
382
1049
  return self._raised
383
1050
 
384
- def handle_success(self, result: R, transaction: Transaction) -> R:
1051
+ async def handle_success(self, result: R, transaction: Transaction) -> R:
385
1052
  result_factory = getattr(TaskRunContext.get(), "result_factory", None)
386
1053
  if result_factory is None:
387
1054
  raise ValueError("Result factory is not set")
@@ -391,19 +1058,18 @@ class TaskRunEngine(Generic[P, R]):
391
1058
  else:
392
1059
  expiration = None
393
1060
 
394
- terminal_state = run_coro_as_sync(
395
- return_value_to_state(
396
- result,
397
- result_factory=result_factory,
398
- key=transaction.key,
399
- expiration=expiration,
400
- # defer persistence to transaction commit
401
- defer_persistence=True,
402
- )
1061
+ terminal_state = await return_value_to_state(
1062
+ result,
1063
+ result_factory=result_factory,
1064
+ key=transaction.key,
1065
+ expiration=expiration,
1066
+ # defer persistence to transaction commit
1067
+ defer_persistence=True,
403
1068
  )
404
1069
  transaction.stage(
405
1070
  terminal_state.data,
406
- on_rollback_hooks=[
1071
+ on_rollback_hooks=[self.handle_rollback]
1072
+ + [
407
1073
  _with_transaction_hook_logging(hook, "rollback", self.logger)
408
1074
  for hook in self.task.on_rollback_hooks
409
1075
  ],
@@ -416,18 +1082,18 @@ class TaskRunEngine(Generic[P, R]):
416
1082
  terminal_state.name = "Cached"
417
1083
 
418
1084
  self.record_terminal_state_timing(terminal_state)
419
- self.set_state(terminal_state)
1085
+ await self.set_state(terminal_state)
420
1086
  self._return_value = result
421
1087
  return result
422
1088
 
423
- def handle_retry(self, exc: Exception) -> bool:
1089
+ async def handle_retry(self, exc: Exception) -> bool:
424
1090
  """Handle any task run retries.
425
1091
 
426
1092
  - If the task has retries left, and the retry condition is met, set the task to retrying and return True.
427
- - If the task has a retry delay, place in AwaitingRetry state with a delayed scheduled time.
1093
+ - If the task has a retry delay, place in AwaitingRetry state with a delayed scheduled time.
428
1094
  - If the task has no retries left, or the retry condition is not met, return False.
429
1095
  """
430
- if self.retries < self.task.retries and self.can_retry(exc):
1096
+ if self.retries < self.task.retries and await self.can_retry(exc):
431
1097
  if self.task.retry_delay_seconds:
432
1098
  delay = (
433
1099
  self.task.retry_delay_seconds[
@@ -453,7 +1119,7 @@ class TaskRunEngine(Generic[P, R]):
453
1119
  str(delay) + " second(s) from now" if delay else "immediately",
454
1120
  )
455
1121
 
456
- self.set_state(new_state, force=True)
1122
+ await self.set_state(new_state, force=True)
457
1123
  self.retries = self.retries + 1
458
1124
  return True
459
1125
  elif self.retries >= self.task.retries:
@@ -466,24 +1132,22 @@ class TaskRunEngine(Generic[P, R]):
466
1132
 
467
1133
  return False
468
1134
 
469
- def handle_exception(self, exc: Exception) -> None:
1135
+ async def handle_exception(self, exc: Exception) -> None:
470
1136
  # If the task fails, and we have retries left, set the task to retrying.
471
- if not self.handle_retry(exc):
1137
+ if not await self.handle_retry(exc):
472
1138
  # If the task has no retries left, or the retry condition is not met, set the task to failed.
473
1139
  context = TaskRunContext.get()
474
- state = run_coro_as_sync(
475
- exception_to_failed_state(
476
- exc,
477
- message="Task run encountered an exception",
478
- result_factory=getattr(context, "result_factory", None),
479
- )
1140
+ state = await exception_to_failed_state(
1141
+ exc,
1142
+ message="Task run encountered an exception",
1143
+ result_factory=getattr(context, "result_factory", None),
480
1144
  )
481
1145
  self.record_terminal_state_timing(state)
482
- self.set_state(state)
1146
+ await self.set_state(state)
483
1147
  self._raised = exc
484
1148
 
485
- def handle_timeout(self, exc: TimeoutError) -> None:
486
- if not self.handle_retry(exc):
1149
+ async def handle_timeout(self, exc: TimeoutError) -> None:
1150
+ if not await self.handle_retry(exc):
487
1151
  if isinstance(exc, TaskRunTimeoutError):
488
1152
  message = f"Task run exceeded timeout of {self.task.timeout_seconds} second(s)"
489
1153
  else:
@@ -494,29 +1158,19 @@ class TaskRunEngine(Generic[P, R]):
494
1158
  message=message,
495
1159
  name="TimedOut",
496
1160
  )
497
- self.set_state(state)
1161
+ await self.set_state(state)
498
1162
  self._raised = exc
499
1163
 
500
- def handle_crash(self, exc: BaseException) -> None:
501
- state = run_coro_as_sync(exception_to_crashed_state(exc))
1164
+ async def handle_crash(self, exc: BaseException) -> None:
1165
+ state = await exception_to_crashed_state(exc)
502
1166
  self.logger.error(f"Crash detected! {state.message}")
503
1167
  self.logger.debug("Crash details:", exc_info=exc)
504
1168
  self.record_terminal_state_timing(state)
505
- self.set_state(state, force=True)
1169
+ await self.set_state(state, force=True)
506
1170
  self._raised = exc
507
1171
 
508
- def record_terminal_state_timing(self, state: State) -> None:
509
- if PREFECT_EXPERIMENTAL_ENABLE_CLIENT_SIDE_TASK_ORCHESTRATION:
510
- if self.task_run.start_time and not self.task_run.end_time:
511
- self.task_run.end_time = state.timestamp
512
-
513
- if self.task_run.state.is_running():
514
- self.task_run.total_run_time += (
515
- state.timestamp - self.task_run.state.timestamp
516
- )
517
-
518
- @contextmanager
519
- def setup_run_context(self, client: Optional[SyncPrefectClient] = None):
1172
+ @asynccontextmanager
1173
+ async def setup_run_context(self, client: Optional[PrefectClient] = None):
520
1174
  from prefect.utilities.engine import (
521
1175
  _resolve_custom_task_run_name,
522
1176
  should_log_prints,
@@ -528,7 +1182,7 @@ class TaskRunEngine(Generic[P, R]):
528
1182
  raise ValueError("Task run is not set")
529
1183
 
530
1184
  if not PREFECT_EXPERIMENTAL_ENABLE_CLIENT_SIDE_TASK_ORCHESTRATION:
531
- self.task_run = client.read_task_run(self.task_run.id)
1185
+ self.task_run = await client.read_task_run(self.task_run.id)
532
1186
  with ExitStack() as stack:
533
1187
  if log_prints := should_log_prints(self.task):
534
1188
  stack.enter_context(patch_print())
@@ -538,10 +1192,11 @@ class TaskRunEngine(Generic[P, R]):
538
1192
  log_prints=log_prints,
539
1193
  task_run=self.task_run,
540
1194
  parameters=self.parameters,
541
- result_factory=run_coro_as_sync(ResultFactory.from_task(self.task)), # type: ignore
1195
+ result_factory=await ResultFactory.from_task(self.task), # type: ignore
542
1196
  client=client,
543
1197
  )
544
1198
  )
1199
+ stack.enter_context(ConcurrencyContext())
545
1200
 
546
1201
  self.logger = task_run_logger(task_run=self.task_run, task=self.task) # type: ignore
547
1202
 
@@ -551,7 +1206,7 @@ class TaskRunEngine(Generic[P, R]):
551
1206
  task_run_name = _resolve_custom_task_run_name(
552
1207
  task=self.task, parameters=self.parameters
553
1208
  )
554
- self.client.set_task_run_name(
1209
+ await self.client.set_task_run_name(
555
1210
  task_run_id=self.task_run.id, name=task_run_name
556
1211
  )
557
1212
  self.logger.extra["task_run_name"] = task_run_name
@@ -562,19 +1217,19 @@ class TaskRunEngine(Generic[P, R]):
562
1217
  self._task_name_set = True
563
1218
  yield
564
1219
 
565
- @contextmanager
566
- def initialize_run(
1220
+ @asynccontextmanager
1221
+ async def initialize_run(
567
1222
  self,
568
1223
  task_run_id: Optional[UUID] = None,
569
1224
  dependencies: Optional[Dict[str, Set[TaskRunInput]]] = None,
570
- ) -> Generator["TaskRunEngine", Any, Any]:
1225
+ ) -> AsyncGenerator["AsyncTaskRunEngine", Any]:
571
1226
  """
572
1227
  Enters a client context and creates a task run if needed.
573
1228
  """
574
1229
 
575
1230
  with hydrated_context(self.context):
576
- with ClientContext.get_or_create() as client_ctx:
577
- self._client = client_ctx.sync_client
1231
+ async with AsyncClientContext.get_or_create():
1232
+ self._client = get_client()
578
1233
  self._is_started = True
579
1234
  try:
580
1235
  if PREFECT_EXPERIMENTAL_ENABLE_CLIENT_SIDE_TASK_ORCHESTRATION:
@@ -594,37 +1249,39 @@ class TaskRunEngine(Generic[P, R]):
594
1249
  self.task_run.name = task_run_name
595
1250
 
596
1251
  if not self.task_run:
597
- self.task_run = run_coro_as_sync(
598
- self.task.create_local_run(
599
- id=task_run_id,
600
- parameters=self.parameters,
601
- flow_run_context=FlowRunContext.get(),
602
- parent_task_run_context=TaskRunContext.get(),
603
- wait_for=self.wait_for,
604
- extra_task_inputs=dependencies,
605
- task_run_name=task_run_name,
606
- )
1252
+ self.task_run = await self.task.create_local_run(
1253
+ id=task_run_id,
1254
+ parameters=self.parameters,
1255
+ flow_run_context=FlowRunContext.get(),
1256
+ parent_task_run_context=TaskRunContext.get(),
1257
+ wait_for=self.wait_for,
1258
+ extra_task_inputs=dependencies,
1259
+ task_run_name=task_run_name,
1260
+ )
1261
+ # Emit an event to capture that the task run was in the `PENDING` state.
1262
+ self._last_event = emit_task_run_state_change_event(
1263
+ task_run=self.task_run,
1264
+ initial_state=None,
1265
+ validated_state=self.task_run.state,
607
1266
  )
608
1267
  else:
609
1268
  if not self.task_run:
610
- self.task_run = run_coro_as_sync(
611
- self.task.create_run(
612
- id=task_run_id,
613
- parameters=self.parameters,
614
- flow_run_context=FlowRunContext.get(),
615
- parent_task_run_context=TaskRunContext.get(),
616
- wait_for=self.wait_for,
617
- extra_task_inputs=dependencies,
618
- )
1269
+ self.task_run = await self.task.create_run(
1270
+ id=task_run_id,
1271
+ parameters=self.parameters,
1272
+ flow_run_context=FlowRunContext.get(),
1273
+ parent_task_run_context=TaskRunContext.get(),
1274
+ wait_for=self.wait_for,
1275
+ extra_task_inputs=dependencies,
1276
+ )
1277
+ # Emit an event to capture that the task run was in the `PENDING` state.
1278
+ self._last_event = emit_task_run_state_change_event(
1279
+ task_run=self.task_run,
1280
+ initial_state=None,
1281
+ validated_state=self.task_run.state,
619
1282
  )
620
- # Emit an event to capture that the task run was in the `PENDING` state.
621
- self._last_event = emit_task_run_state_change_event(
622
- task_run=self.task_run,
623
- initial_state=None,
624
- validated_state=self.task_run.state,
625
- )
626
1283
 
627
- with self.setup_run_context():
1284
+ async with self.setup_run_context():
628
1285
  # setup_run_context might update the task run name, so log creation here
629
1286
  self.logger.info(
630
1287
  f"Created task run {self.task_run.name!r} for task {self.task.name!r}"
@@ -633,7 +1290,7 @@ class TaskRunEngine(Generic[P, R]):
633
1290
 
634
1291
  except TerminationSignal as exc:
635
1292
  # TerminationSignals are caught and handled as crashes
636
- self.handle_crash(exc)
1293
+ await self.handle_crash(exc)
637
1294
  raise exc
638
1295
 
639
1296
  except Exception:
@@ -649,60 +1306,19 @@ class TaskRunEngine(Generic[P, R]):
649
1306
  raise
650
1307
  except BaseException as exc:
651
1308
  # BaseExceptions are caught and handled as crashes
652
- self.handle_crash(exc)
1309
+ await self.handle_crash(exc)
653
1310
  raise
654
1311
  finally:
655
- # If debugging, use the more complete `repr` than the usual `str` description
656
- display_state = (
657
- repr(self.state) if PREFECT_DEBUG_MODE else str(self.state)
658
- )
659
- level = logging.INFO if self.state.is_completed() else logging.ERROR
660
- msg = f"Finished in state {display_state}"
661
- if self.state.is_pending():
662
- msg += (
663
- "\nPlease wait for all submitted tasks to complete"
664
- " before exiting your flow by calling `.wait()` on the "
665
- "`PrefectFuture` returned from your `.submit()` calls."
666
- )
667
- msg += dedent(
668
- """
669
-
670
- Example:
671
-
672
- from prefect import flow, task
673
-
674
- @task
675
- def say_hello(name):
676
- print f"Hello, {name}!"
677
-
678
- @flow
679
- def example_flow():
680
- future = say_hello.submit(name="Marvin)
681
- future.wait()
682
-
683
- example_flow()
684
- """
685
- )
686
- self.logger.log(
687
- level=level,
688
- msg=msg,
689
- )
690
-
1312
+ self.log_finished_message()
691
1313
  self._is_started = False
692
1314
  self._client = None
693
1315
 
694
- def is_running(self) -> bool:
695
- """Whether or not the engine is currently running a task."""
696
- if (task_run := getattr(self, "task_run", None)) is None:
697
- return False
698
- return task_run.state.is_running() or task_run.state.is_scheduled()
699
-
700
1316
  async def wait_until_ready(self):
701
1317
  """Waits until the scheduled time (if its the future), then enters Running."""
702
1318
  if scheduled_time := self.state.state_details.scheduled_time:
703
1319
  sleep_time = (scheduled_time - pendulum.now("utc")).total_seconds()
704
1320
  await anyio.sleep(sleep_time if sleep_time > 0 else 0)
705
- self.set_state(
1321
+ await self.set_state(
706
1322
  Retrying() if self.state.name == "AwaitingRetry" else Running(),
707
1323
  force=True,
708
1324
  )
@@ -713,21 +1329,23 @@ class TaskRunEngine(Generic[P, R]):
713
1329
  #
714
1330
  # --------------------------
715
1331
 
716
- @contextmanager
717
- def start(
1332
+ @asynccontextmanager
1333
+ async def start(
718
1334
  self,
719
1335
  task_run_id: Optional[UUID] = None,
720
1336
  dependencies: Optional[Dict[str, Set[TaskRunInput]]] = None,
721
- ) -> Generator[None, None, None]:
722
- with self.initialize_run(task_run_id=task_run_id, dependencies=dependencies):
723
- self.begin_run()
1337
+ ) -> AsyncGenerator[None, None]:
1338
+ async with self.initialize_run(
1339
+ task_run_id=task_run_id, dependencies=dependencies
1340
+ ):
1341
+ await self.begin_run()
724
1342
  try:
725
1343
  yield
726
1344
  finally:
727
- self.call_hooks()
1345
+ await self.call_hooks()
728
1346
 
729
- @contextmanager
730
- def transaction_context(self) -> Generator[Transaction, None, None]:
1347
+ @asynccontextmanager
1348
+ async def transaction_context(self) -> AsyncGenerator[Transaction, None]:
731
1349
  result_factory = getattr(TaskRunContext.get(), "result_factory", None)
732
1350
 
733
1351
  # refresh cache setting is now repurposes as overwrite transaction record
@@ -744,13 +1362,12 @@ class TaskRunEngine(Generic[P, R]):
744
1362
  ) as txn:
745
1363
  yield txn
746
1364
 
747
- @contextmanager
748
- def run_context(self):
749
- timeout_context = timeout_async if self.task.isasync else timeout
1365
+ @asynccontextmanager
1366
+ async def run_context(self):
750
1367
  # reenter the run context to ensure it is up to date for every run
751
- with self.setup_run_context():
1368
+ async with self.setup_run_context():
752
1369
  try:
753
- with timeout_context(
1370
+ with timeout_async(
754
1371
  seconds=self.task.timeout_seconds,
755
1372
  timeout_exc_type=TaskRunTimeoutError,
756
1373
  ):
@@ -762,11 +1379,11 @@ class TaskRunEngine(Generic[P, R]):
762
1379
 
763
1380
  yield self
764
1381
  except TimeoutError as exc:
765
- self.handle_timeout(exc)
1382
+ await self.handle_timeout(exc)
766
1383
  except Exception as exc:
767
- self.handle_exception(exc)
1384
+ await self.handle_exception(exc)
768
1385
 
769
- def call_task_fn(
1386
+ async def call_task_fn(
770
1387
  self, transaction: Transaction
771
1388
  ) -> Union[R, Coroutine[Any, Any, R]]:
772
1389
  """
@@ -774,24 +1391,23 @@ class TaskRunEngine(Generic[P, R]):
774
1391
  task is async.
775
1392
  """
776
1393
  parameters = self.parameters or {}
777
- if self.task.isasync:
778
-
779
- async def _call_task_fn():
780
- if transaction.is_committed():
781
- result = transaction.read()
782
- else:
783
- result = await call_with_parameters(self.task.fn, parameters)
784
- self.handle_success(result, transaction=transaction)
785
- return result
786
-
787
- return _call_task_fn()
1394
+ if transaction.is_committed():
1395
+ result = transaction.read()
788
1396
  else:
789
- if transaction.is_committed():
790
- result = transaction.read()
1397
+ if (
1398
+ PREFECT_EXPERIMENTAL_ENABLE_CLIENT_SIDE_TASK_ORCHESTRATION.value()
1399
+ and self.task.tags
1400
+ ):
1401
+ # Acquire a concurrency slot for each tag, but only if a limit
1402
+ # matching the tag already exists.
1403
+ async with aconcurrency(
1404
+ list(self.task.tags), occupy=1, create_if_missing=False
1405
+ ):
1406
+ result = await call_with_parameters(self.task.fn, parameters)
791
1407
  else:
792
- result = call_with_parameters(self.task.fn, parameters)
793
- self.handle_success(result, transaction=transaction)
794
- return result
1408
+ result = await call_with_parameters(self.task.fn, parameters)
1409
+ await self.handle_success(result, transaction=transaction)
1410
+ return result
795
1411
 
796
1412
 
797
1413
  def run_task_sync(
@@ -804,7 +1420,7 @@ def run_task_sync(
804
1420
  dependencies: Optional[Dict[str, Set[TaskRunInput]]] = None,
805
1421
  context: Optional[Dict[str, Any]] = None,
806
1422
  ) -> Union[R, State, None]:
807
- engine = TaskRunEngine[P, R](
1423
+ engine = SyncTaskRunEngine[P, R](
808
1424
  task=task,
809
1425
  parameters=parameters,
810
1426
  task_run=task_run,
@@ -831,7 +1447,7 @@ async def run_task_async(
831
1447
  dependencies: Optional[Dict[str, Set[TaskRunInput]]] = None,
832
1448
  context: Optional[Dict[str, Any]] = None,
833
1449
  ) -> Union[R, State, None]:
834
- engine = TaskRunEngine[P, R](
1450
+ engine = AsyncTaskRunEngine[P, R](
835
1451
  task=task,
836
1452
  parameters=parameters,
837
1453
  task_run=task_run,
@@ -839,13 +1455,13 @@ async def run_task_async(
839
1455
  context=context,
840
1456
  )
841
1457
 
842
- with engine.start(task_run_id=task_run_id, dependencies=dependencies):
1458
+ async with engine.start(task_run_id=task_run_id, dependencies=dependencies):
843
1459
  while engine.is_running():
844
1460
  await engine.wait_until_ready()
845
- with engine.run_context(), engine.transaction_context() as txn:
1461
+ async with engine.run_context(), engine.transaction_context() as txn:
846
1462
  await engine.call_task_fn(txn)
847
1463
 
848
- return engine.state if return_type == "state" else engine.result()
1464
+ return engine.state if return_type == "state" else await engine.result()
849
1465
 
850
1466
 
851
1467
  def run_generator_task_sync(
@@ -861,7 +1477,7 @@ def run_generator_task_sync(
861
1477
  if return_type != "result":
862
1478
  raise ValueError("The return_type for a generator task must be 'result'")
863
1479
 
864
- engine = TaskRunEngine[P, R](
1480
+ engine = SyncTaskRunEngine[P, R](
865
1481
  task=task,
866
1482
  parameters=parameters,
867
1483
  task_run=task_run,
@@ -915,7 +1531,7 @@ async def run_generator_task_async(
915
1531
  ) -> AsyncGenerator[R, None]:
916
1532
  if return_type != "result":
917
1533
  raise ValueError("The return_type for a generator task must be 'result'")
918
- engine = TaskRunEngine[P, R](
1534
+ engine = AsyncTaskRunEngine[P, R](
919
1535
  task=task,
920
1536
  parameters=parameters,
921
1537
  task_run=task_run,
@@ -923,10 +1539,10 @@ async def run_generator_task_async(
923
1539
  context=context,
924
1540
  )
925
1541
 
926
- with engine.start(task_run_id=task_run_id, dependencies=dependencies):
1542
+ async with engine.start(task_run_id=task_run_id, dependencies=dependencies):
927
1543
  while engine.is_running():
928
1544
  await engine.wait_until_ready()
929
- with engine.run_context(), engine.transaction_context() as txn:
1545
+ async with engine.run_context(), engine.transaction_context() as txn:
930
1546
  # TODO: generators should default to commit_mode=OFF
931
1547
  # because they are dynamic by definition
932
1548
  # for now we just prevent this branch explicitly
@@ -950,13 +1566,13 @@ async def run_generator_task_async(
950
1566
  link_state_to_result(engine.state, gen_result)
951
1567
  yield gen_result
952
1568
  except (StopAsyncIteration, GeneratorExit) as exc:
953
- engine.handle_success(None, transaction=txn)
1569
+ await engine.handle_success(None, transaction=txn)
954
1570
  if isinstance(exc, GeneratorExit):
955
1571
  gen.throw(exc)
956
1572
 
957
1573
  # async generators can't return, but we can raise failures here
958
1574
  if engine.state.is_failed():
959
- engine.result()
1575
+ await engine.result()
960
1576
 
961
1577
 
962
1578
  def run_task(