prefect-client 3.0.0rc1__py3-none-any.whl → 3.0.0rc3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (72) hide show
  1. prefect/_internal/compatibility/migration.py +124 -0
  2. prefect/_internal/concurrency/__init__.py +2 -2
  3. prefect/_internal/concurrency/primitives.py +1 -0
  4. prefect/_internal/pydantic/annotations/pendulum.py +2 -2
  5. prefect/_internal/pytz.py +1 -1
  6. prefect/blocks/core.py +1 -1
  7. prefect/blocks/redis.py +168 -0
  8. prefect/client/orchestration.py +113 -23
  9. prefect/client/schemas/actions.py +1 -1
  10. prefect/client/schemas/filters.py +6 -0
  11. prefect/client/schemas/objects.py +22 -11
  12. prefect/client/subscriptions.py +3 -2
  13. prefect/concurrency/asyncio.py +1 -1
  14. prefect/concurrency/services.py +1 -1
  15. prefect/context.py +1 -27
  16. prefect/deployments/__init__.py +3 -0
  17. prefect/deployments/base.py +11 -3
  18. prefect/deployments/deployments.py +3 -0
  19. prefect/deployments/steps/pull.py +1 -0
  20. prefect/deployments/steps/utility.py +2 -1
  21. prefect/engine.py +3 -0
  22. prefect/events/cli/automations.py +1 -1
  23. prefect/events/clients.py +7 -1
  24. prefect/events/schemas/events.py +2 -0
  25. prefect/exceptions.py +9 -0
  26. prefect/filesystems.py +22 -11
  27. prefect/flow_engine.py +118 -156
  28. prefect/flow_runs.py +2 -2
  29. prefect/flows.py +91 -35
  30. prefect/futures.py +44 -43
  31. prefect/infrastructure/provisioners/container_instance.py +1 -0
  32. prefect/infrastructure/provisioners/ecs.py +2 -2
  33. prefect/input/__init__.py +4 -0
  34. prefect/input/run_input.py +4 -2
  35. prefect/logging/formatters.py +2 -2
  36. prefect/logging/handlers.py +2 -2
  37. prefect/logging/loggers.py +1 -1
  38. prefect/plugins.py +1 -0
  39. prefect/records/cache_policies.py +179 -0
  40. prefect/records/result_store.py +10 -3
  41. prefect/results.py +27 -55
  42. prefect/runner/runner.py +1 -1
  43. prefect/runner/server.py +1 -1
  44. prefect/runtime/__init__.py +1 -0
  45. prefect/runtime/deployment.py +1 -0
  46. prefect/runtime/flow_run.py +1 -0
  47. prefect/runtime/task_run.py +1 -0
  48. prefect/settings.py +21 -5
  49. prefect/states.py +17 -4
  50. prefect/task_engine.py +337 -209
  51. prefect/task_runners.py +15 -5
  52. prefect/task_runs.py +203 -0
  53. prefect/{task_server.py → task_worker.py} +66 -36
  54. prefect/tasks.py +180 -77
  55. prefect/transactions.py +92 -16
  56. prefect/types/__init__.py +1 -1
  57. prefect/utilities/asyncutils.py +3 -3
  58. prefect/utilities/callables.py +90 -7
  59. prefect/utilities/dockerutils.py +5 -3
  60. prefect/utilities/engine.py +11 -0
  61. prefect/utilities/filesystem.py +4 -5
  62. prefect/utilities/importtools.py +34 -5
  63. prefect/utilities/services.py +2 -2
  64. prefect/utilities/urls.py +195 -0
  65. prefect/utilities/visualization.py +1 -0
  66. prefect/variables.py +19 -10
  67. prefect/workers/base.py +46 -1
  68. {prefect_client-3.0.0rc1.dist-info → prefect_client-3.0.0rc3.dist-info}/METADATA +3 -2
  69. {prefect_client-3.0.0rc1.dist-info → prefect_client-3.0.0rc3.dist-info}/RECORD +72 -66
  70. {prefect_client-3.0.0rc1.dist-info → prefect_client-3.0.0rc3.dist-info}/LICENSE +0 -0
  71. {prefect_client-3.0.0rc1.dist-info → prefect_client-3.0.0rc3.dist-info}/WHEEL +0 -0
  72. {prefect_client-3.0.0rc1.dist-info → prefect_client-3.0.0rc3.dist-info}/top_level.txt +0 -0
prefect/task_engine.py CHANGED
@@ -3,8 +3,10 @@ import logging
3
3
  import time
4
4
  from contextlib import ExitStack, contextmanager
5
5
  from dataclasses import dataclass, field
6
+ from textwrap import dedent
6
7
  from typing import (
7
8
  Any,
9
+ AsyncGenerator,
8
10
  Callable,
9
11
  Coroutine,
10
12
  Dict,
@@ -13,17 +15,18 @@ from typing import (
13
15
  Iterable,
14
16
  Literal,
15
17
  Optional,
18
+ Sequence,
16
19
  Set,
17
20
  TypeVar,
18
21
  Union,
19
22
  )
20
23
  from uuid import UUID
21
24
 
25
+ import anyio
22
26
  import pendulum
23
27
  from typing_extensions import ParamSpec
24
28
 
25
29
  from prefect import Task
26
- from prefect._internal.concurrency.api import create_call, from_sync
27
30
  from prefect.client.orchestration import SyncPrefectClient
28
31
  from prefect.client.schemas import TaskRun
29
32
  from prefect.client.schemas.objects import State, TaskRunInput
@@ -41,7 +44,6 @@ from prefect.exceptions import (
41
44
  UpstreamTaskError,
42
45
  )
43
46
  from prefect.futures import PrefectFuture
44
- from prefect.logging.handlers import APILogHandler
45
47
  from prefect.logging.loggers import get_logger, patch_print, task_run_logger
46
48
  from prefect.records.result_store import ResultFactoryStore
47
49
  from prefect.results import ResultFactory, _format_user_supplied_storage_key
@@ -50,23 +52,24 @@ from prefect.settings import (
50
52
  PREFECT_TASKS_REFRESH_CACHE,
51
53
  )
52
54
  from prefect.states import (
55
+ AwaitingRetry,
53
56
  Failed,
54
57
  Paused,
55
58
  Pending,
56
59
  Retrying,
57
60
  Running,
58
- StateDetails,
59
61
  exception_to_crashed_state,
60
62
  exception_to_failed_state,
61
63
  return_value_to_state,
62
64
  )
63
65
  from prefect.transactions import Transaction, transaction
64
66
  from prefect.utilities.asyncutils import run_coro_as_sync
65
- from prefect.utilities.callables import parameters_to_args_kwargs
67
+ from prefect.utilities.callables import call_with_parameters, parameters_to_args_kwargs
66
68
  from prefect.utilities.collections import visit_collection
67
69
  from prefect.utilities.engine import (
68
70
  _get_hook_name,
69
71
  emit_task_run_state_change_event,
72
+ link_state_to_result,
70
73
  propose_state_sync,
71
74
  resolve_to_final_result,
72
75
  )
@@ -133,101 +136,54 @@ class TaskRunEngine(Generic[P, R]):
133
136
  )
134
137
  return False
135
138
 
136
- def get_hooks(self, state: State, as_async: bool = False) -> Iterable[Callable]:
139
+ def call_hooks(self, state: State = None) -> Iterable[Callable]:
140
+ if state is None:
141
+ state = self.state
137
142
  task = self.task
138
143
  task_run = self.task_run
139
144
 
140
145
  if not task_run:
141
146
  raise ValueError("Task run is not set")
142
147
 
143
- hooks = None
144
148
  if state.is_failed() and task.on_failure_hooks:
145
149
  hooks = task.on_failure_hooks
146
150
  elif state.is_completed() and task.on_completion_hooks:
147
151
  hooks = task.on_completion_hooks
152
+ else:
153
+ hooks = None
148
154
 
149
155
  for hook in hooks or []:
150
156
  hook_name = _get_hook_name(hook)
151
157
 
152
- @contextmanager
153
- def hook_context():
154
- try:
155
- self.logger.info(
156
- f"Running hook {hook_name!r} in response to entering state"
157
- f" {state.name!r}"
158
- )
159
- yield
160
- except Exception:
161
- self.logger.error(
162
- f"An error was encountered while running hook {hook_name!r}",
163
- exc_info=True,
164
- )
165
- else:
166
- self.logger.info(
167
- f"Hook {hook_name!r} finished running successfully"
168
- )
169
-
170
- if as_async:
171
-
172
- async def _hook_fn():
173
- with hook_context():
174
- result = hook(task, task_run, state)
175
- if inspect.isawaitable(result):
176
- await result
177
-
158
+ try:
159
+ self.logger.info(
160
+ f"Running hook {hook_name!r} in response to entering state"
161
+ f" {state.name!r}"
162
+ )
163
+ result = hook(task, task_run, state)
164
+ if inspect.isawaitable(result):
165
+ run_coro_as_sync(result)
166
+ except Exception:
167
+ self.logger.error(
168
+ f"An error was encountered while running hook {hook_name!r}",
169
+ exc_info=True,
170
+ )
178
171
  else:
179
-
180
- def _hook_fn():
181
- with hook_context():
182
- result = hook(task, task_run, state)
183
- if inspect.isawaitable(result):
184
- run_coro_as_sync(result)
185
-
186
- yield _hook_fn
172
+ self.logger.info(f"Hook {hook_name!r} finished running successfully")
187
173
 
188
174
  def compute_transaction_key(self) -> str:
189
- if self.task.result_storage_key is not None:
175
+ key = None
176
+ if self.task.cache_policy:
177
+ task_run_context = TaskRunContext.get()
178
+ key = self.task.cache_policy.compute_key(
179
+ task_ctx=task_run_context,
180
+ inputs=self.parameters,
181
+ flow_parameters=None,
182
+ )
183
+ elif self.task.result_storage_key is not None:
190
184
  key = _format_user_supplied_storage_key(self.task.result_storage_key)
191
- else:
192
- key = str(self.task_run.id)
193
185
  return key
194
186
 
195
- def _compute_state_details(
196
- self, include_cache_expiration: bool = False
197
- ) -> StateDetails:
198
- task_run_context = TaskRunContext.get()
199
- ## setup cache metadata
200
- cache_key = (
201
- self.task.cache_key_fn(
202
- task_run_context,
203
- self.parameters or {},
204
- )
205
- if self.task.cache_key_fn
206
- else None
207
- )
208
- # Ignore the cached results for a cache key, default = false
209
- # Setting on task level overrules the Prefect setting (env var)
210
- refresh_cache = (
211
- self.task.refresh_cache
212
- if self.task.refresh_cache is not None
213
- else PREFECT_TASKS_REFRESH_CACHE.value()
214
- )
215
-
216
- if include_cache_expiration:
217
- cache_expiration = (
218
- (pendulum.now("utc") + self.task.cache_expiration)
219
- if self.task.cache_expiration
220
- else None
221
- )
222
- else:
223
- cache_expiration = None
224
-
225
- return StateDetails(
226
- cache_key=cache_key,
227
- refresh_cache=refresh_cache,
228
- cache_expiration=cache_expiration,
229
- )
230
-
231
187
  def _resolve_parameters(self):
232
188
  if not self.parameters:
233
189
  return {}
@@ -264,7 +220,7 @@ class TaskRunEngine(Generic[P, R]):
264
220
  return_data=False,
265
221
  max_depth=-1,
266
222
  remove_annotations=True,
267
- context={},
223
+ context={"current_task_run": self.task_run, "current_task": self.task},
268
224
  )
269
225
 
270
226
  def begin_run(self):
@@ -283,8 +239,7 @@ class TaskRunEngine(Generic[P, R]):
283
239
  )
284
240
  return
285
241
 
286
- state_details = self._compute_state_details()
287
- new_state = Running(state_details=state_details)
242
+ new_state = Running()
288
243
  state = self.set_state(new_state)
289
244
 
290
245
  BACKOFF_MAX = 10
@@ -344,17 +299,17 @@ class TaskRunEngine(Generic[P, R]):
344
299
  if result_factory is None:
345
300
  raise ValueError("Result factory is not set")
346
301
 
347
- # dont put this inside function, else the transaction could get serialized
348
- key = transaction.key
349
-
350
- def key_fn():
351
- return key
302
+ if self.task.cache_expiration is not None:
303
+ expiration = pendulum.now("utc") + self.task.cache_expiration
304
+ else:
305
+ expiration = None
352
306
 
353
- result_factory.storage_key_fn = key_fn
354
307
  terminal_state = run_coro_as_sync(
355
308
  return_value_to_state(
356
309
  result,
357
310
  result_factory=result_factory,
311
+ key=transaction.key,
312
+ expiration=expiration,
358
313
  )
359
314
  )
360
315
  transaction.stage(
@@ -362,22 +317,49 @@ class TaskRunEngine(Generic[P, R]):
362
317
  on_rollback_hooks=self.task.on_rollback_hooks,
363
318
  on_commit_hooks=self.task.on_commit_hooks,
364
319
  )
365
- terminal_state.state_details = self._compute_state_details(
366
- include_cache_expiration=True
367
- )
320
+ if transaction.is_committed():
321
+ terminal_state.name = "Cached"
368
322
  self.set_state(terminal_state)
369
323
  return result
370
324
 
371
325
  def handle_retry(self, exc: Exception) -> bool:
372
- """
373
- If the task has retries left, and the retry condition is met, set the task to retrying.
326
+ """Handle any task run retries.
327
+
328
+ - If the task has retries left, and the retry condition is met, set the task to retrying and return True.
329
+ - If the task has a retry delay, place in AwaitingRetry state with a delayed scheduled time.
374
330
  - If the task has no retries left, or the retry condition is not met, return False.
375
- - If the task has retries left, and the retry condition is met, return True.
376
331
  """
377
332
  if self.retries < self.task.retries and self.can_retry:
378
- self.set_state(Retrying(), force=True)
333
+ if self.task.retry_delay_seconds:
334
+ delay = (
335
+ self.task.retry_delay_seconds[
336
+ min(self.retries, len(self.task.retry_delay_seconds) - 1)
337
+ ] # repeat final delay value if attempts exceed specified delays
338
+ if isinstance(self.task.retry_delay_seconds, Sequence)
339
+ else self.task.retry_delay_seconds
340
+ )
341
+ new_state = AwaitingRetry(
342
+ scheduled_time=pendulum.now("utc").add(seconds=delay)
343
+ )
344
+ else:
345
+ delay = None
346
+ new_state = Retrying()
347
+
348
+ self.logger.info(
349
+ f"Task run failed with exception {exc!r} - "
350
+ f"Retry {self.retries + 1}/{self.task.retries} will start "
351
+ f"{str(delay) + ' second(s) from now' if delay else 'immediately'}"
352
+ )
353
+
354
+ self.set_state(new_state, force=True)
379
355
  self.retries = self.retries + 1
380
356
  return True
357
+ elif self.retries >= self.task.retries:
358
+ self.logger.error(
359
+ f"Task run failed with exception {exc!r} - Retries are exhausted"
360
+ )
361
+ return False
362
+
381
363
  return False
382
364
 
383
365
  def handle_exception(self, exc: Exception) -> None:
@@ -414,7 +396,7 @@ class TaskRunEngine(Generic[P, R]):
414
396
  self.set_state(state, force=True)
415
397
 
416
398
  @contextmanager
417
- def enter_run_context(self, client: Optional[SyncPrefectClient] = None):
399
+ def setup_run_context(self, client: Optional[SyncPrefectClient] = None):
418
400
  from prefect.utilities.engine import (
419
401
  _resolve_custom_task_run_name,
420
402
  should_log_prints,
@@ -461,7 +443,7 @@ class TaskRunEngine(Generic[P, R]):
461
443
  yield
462
444
 
463
445
  @contextmanager
464
- def start(
446
+ def initialize_run(
465
447
  self,
466
448
  task_run_id: Optional[UUID] = None,
467
449
  dependencies: Optional[Dict[str, Set[TaskRunInput]]] = None,
@@ -495,12 +477,19 @@ class TaskRunEngine(Generic[P, R]):
495
477
  validated_state=self.task_run.state,
496
478
  )
497
479
 
498
- yield self
480
+ with self.setup_run_context():
481
+ yield self
482
+
499
483
  except Exception:
500
484
  # regular exceptions are caught and re-raised to the user
501
485
  raise
502
- except (Pause, Abort):
486
+ except (Pause, Abort) as exc:
503
487
  # Do not capture internal signals as crashes
488
+ if isinstance(exc, Abort):
489
+ self.logger.error("Task run was aborted: %s", exc)
490
+ raise
491
+ except GeneratorExit:
492
+ # Do not capture generator exits as crashes
504
493
  raise
505
494
  except BaseException as exc:
506
495
  # BaseExceptions are caught and handled as crashes
@@ -511,26 +500,133 @@ class TaskRunEngine(Generic[P, R]):
511
500
  display_state = (
512
501
  repr(self.state) if PREFECT_DEBUG_MODE else str(self.state)
513
502
  )
514
- self.logger.log(
515
- level=(
516
- logging.INFO if self.state.is_completed() else logging.ERROR
517
- ),
518
- msg=f"Finished in state {display_state}",
519
- )
503
+ level = logging.INFO if self.state.is_completed() else logging.ERROR
504
+ msg = f"Finished in state {display_state}"
505
+ if self.state.is_pending():
506
+ msg += (
507
+ "\nPlease wait for all submitted tasks to complete"
508
+ " before exiting your flow by calling `.wait()` on the "
509
+ "`PrefectFuture` returned from your `.submit()` calls."
510
+ )
511
+ msg += dedent(
512
+ """
513
+
514
+ Example:
515
+
516
+ from prefect import flow, task
520
517
 
521
- # flush all logs if this is not a "top" level run
522
- if not (FlowRunContext.get() or TaskRunContext.get()):
523
- from_sync.call_soon_in_loop_thread(
524
- create_call(APILogHandler.aflush)
518
+ @task
519
+ def say_hello(name):
520
+ print f"Hello, {name}!"
521
+
522
+ @flow
523
+ def example_flow():
524
+ say_hello.submit(name="Marvin)
525
+ say_hello.wait()
526
+
527
+ example_flow()
528
+ """
525
529
  )
530
+ self.logger.log(
531
+ level=level,
532
+ msg=msg,
533
+ )
526
534
 
527
535
  self._is_started = False
528
536
  self._client = None
529
537
 
530
538
  def is_running(self) -> bool:
531
- if getattr(self, "task_run", None) is None:
539
+ """Whether or not the engine is currently running a task."""
540
+ if (task_run := getattr(self, "task_run", None)) is None:
532
541
  return False
533
- return getattr(self, "task_run").state.is_running()
542
+ return task_run.state.is_running() or task_run.state.is_scheduled()
543
+
544
+ async def wait_until_ready(self):
545
+ """Waits until the scheduled time (if it's in the future), then enters Running."""
546
+ if scheduled_time := self.state.state_details.scheduled_time:
547
+ sleep_time = (scheduled_time - pendulum.now("utc")).total_seconds()
548
+ await anyio.sleep(sleep_time if sleep_time > 0 else 0)
549
+ self.set_state(
550
+ Retrying() if self.state.name == "AwaitingRetry" else Running(),
551
+ force=True,
552
+ )
553
+
554
+ # --------------------------
555
+ #
556
+ # The following methods compose the main task run loop
557
+ #
558
+ # --------------------------
559
+
560
+ @contextmanager
561
+ def start(
562
+ self,
563
+ task_run_id: Optional[UUID] = None,
564
+ dependencies: Optional[Dict[str, Set[TaskRunInput]]] = None,
565
+ ) -> Generator[None, None, None]:
566
+ with self.initialize_run(task_run_id=task_run_id, dependencies=dependencies):
567
+ self.begin_run()
568
+ try:
569
+ yield
570
+ finally:
571
+ self.call_hooks()
572
+
573
+ @contextmanager
574
+ def transaction_context(self) -> Generator[Transaction, None, None]:
575
+ result_factory = getattr(TaskRunContext.get(), "result_factory", None)
576
+
577
+ # the refresh cache setting is now repurposed to overwrite the transaction record
578
+ overwrite = (
579
+ self.task.refresh_cache
580
+ if self.task.refresh_cache is not None
581
+ else PREFECT_TASKS_REFRESH_CACHE.value()
582
+ )
583
+ with transaction(
584
+ key=self.compute_transaction_key(),
585
+ store=ResultFactoryStore(result_factory=result_factory),
586
+ overwrite=overwrite,
587
+ ) as txn:
588
+ yield txn
589
+
590
+ @contextmanager
591
+ def run_context(self):
592
+ timeout_context = timeout_async if self.task.isasync else timeout
593
+ # reenter the run context to ensure it is up to date for every run
594
+ with self.setup_run_context():
595
+ try:
596
+ with timeout_context(seconds=self.task.timeout_seconds):
597
+ self.logger.debug(
598
+ f"Executing task {self.task.name!r} for task run {self.task_run.name!r}..."
599
+ )
600
+ yield self
601
+ except TimeoutError as exc:
602
+ self.handle_timeout(exc)
603
+ except Exception as exc:
604
+ self.handle_exception(exc)
605
+
606
+ def call_task_fn(
607
+ self, transaction: Transaction
608
+ ) -> Union[R, Coroutine[Any, Any, R]]:
609
+ """
610
+ Convenience method to call the task function. Returns a coroutine if the
611
+ task is async.
612
+ """
613
+ parameters = self.parameters or {}
614
+ if self.task.isasync:
615
+
616
+ async def _call_task_fn():
617
+ if transaction.is_committed():
618
+ result = transaction.read()
619
+ else:
620
+ result = await call_with_parameters(self.task.fn, parameters)
621
+ self.handle_success(result, transaction=transaction)
622
+
623
+ return _call_task_fn()
624
+ else:
625
+ if transaction.is_committed():
626
+ result = transaction.read()
627
+ else:
628
+ result = call_with_parameters(self.task.fn, parameters)
629
+ self.handle_success(result, transaction=transaction)
534
630
 
535
631
 
536
632
  def run_task_sync(
@@ -550,56 +646,18 @@ def run_task_sync(
550
646
  wait_for=wait_for,
551
647
  context=context,
552
648
  )
553
- # This is a context manager that keeps track of the run of the task run.
554
- with engine.start(task_run_id=task_run_id, dependencies=dependencies) as run:
555
- with run.enter_run_context():
556
- run.begin_run()
557
- while run.is_running():
558
- # enter run context on each loop iteration to ensure the context
559
- # contains the latest task run metadata
560
- with run.enter_run_context():
561
- try:
562
- # This is where the task is actually run.
563
- with timeout(seconds=run.task.timeout_seconds):
564
- call_args, call_kwargs = parameters_to_args_kwargs(
565
- task.fn, run.parameters or {}
566
- )
567
- run.logger.debug(
568
- f"Executing task {task.name!r} for task run {run.task_run.name!r}..."
569
- )
570
- result_factory = getattr(
571
- TaskRunContext.get(), "result_factory", None
572
- )
573
- with transaction(
574
- key=run.compute_transaction_key(),
575
- store=ResultFactoryStore(result_factory=result_factory),
576
- ) as txn:
577
- if txn.is_committed():
578
- result = txn.read()
579
- else:
580
- result = task.fn(*call_args, **call_kwargs) # type: ignore
581
-
582
- # If the task run is successful, finalize it.
583
- # do this within the transaction lifecycle
584
- # in order to get the proper result serialization
585
- run.handle_success(result, transaction=txn)
586
-
587
- except TimeoutError as exc:
588
- run.handle_timeout(exc)
589
- except Exception as exc:
590
- run.handle_exception(exc)
591
-
592
- if run.state.is_final():
593
- for hook in run.get_hooks(run.state):
594
- hook()
595
-
596
- if return_type == "state":
597
- return run.state
598
- return run.result()
649
+
650
+ with engine.start(task_run_id=task_run_id, dependencies=dependencies):
651
+ while engine.is_running():
652
+ run_coro_as_sync(engine.wait_until_ready())
653
+ with engine.run_context(), engine.transaction_context() as txn:
654
+ engine.call_task_fn(txn)
655
+
656
+ return engine.state if return_type == "state" else engine.result()
599
657
 
600
658
 
601
659
  async def run_task_async(
602
- task: Task[P, Coroutine[Any, Any, R]],
660
+ task: Task[P, R],
603
661
  task_run_id: Optional[UUID] = None,
604
662
  task_run: Optional[TaskRun] = None,
605
663
  parameters: Optional[Dict[str, Any]] = None,
@@ -608,11 +666,35 @@ async def run_task_async(
608
666
  dependencies: Optional[Dict[str, Set[TaskRunInput]]] = None,
609
667
  context: Optional[Dict[str, Any]] = None,
610
668
  ) -> Union[R, State, None]:
611
- """
612
- Runs a task against the API.
669
+ engine = TaskRunEngine[P, R](
670
+ task=task,
671
+ parameters=parameters,
672
+ task_run=task_run,
673
+ wait_for=wait_for,
674
+ context=context,
675
+ )
613
676
 
614
- We will most likely want to use this logic as a wrapper and return a coroutine for type inference.
615
- """
677
+ with engine.start(task_run_id=task_run_id, dependencies=dependencies):
678
+ while engine.is_running():
679
+ await engine.wait_until_ready()
680
+ with engine.run_context(), engine.transaction_context() as txn:
681
+ await engine.call_task_fn(txn)
682
+
683
+ return engine.state if return_type == "state" else engine.result()
684
+
685
+
686
+ def run_generator_task_sync(
687
+ task: Task[P, R],
688
+ task_run_id: Optional[UUID] = None,
689
+ task_run: Optional[TaskRun] = None,
690
+ parameters: Optional[Dict[str, Any]] = None,
691
+ wait_for: Optional[Iterable[PrefectFuture]] = None,
692
+ return_type: Literal["state", "result"] = "result",
693
+ dependencies: Optional[Dict[str, Set[TaskRunInput]]] = None,
694
+ context: Optional[Dict[str, Any]] = None,
695
+ ) -> Generator[R, None, None]:
696
+ if return_type != "result":
697
+ raise ValueError("The return_type for a generator task must be 'result'")
616
698
 
617
699
  engine = TaskRunEngine[P, R](
618
700
  task=task,
@@ -621,53 +703,95 @@ async def run_task_async(
621
703
  wait_for=wait_for,
622
704
  context=context,
623
705
  )
624
- # This is a context manager that keeps track of the run of the task run.
625
- with engine.start(task_run_id=task_run_id, dependencies=dependencies) as run:
626
- with run.enter_run_context():
627
- run.begin_run()
628
-
629
- while run.is_running():
630
- # enter run context on each loop iteration to ensure the context
631
- # contains the latest task run metadata
632
- with run.enter_run_context():
706
+
707
+ with engine.start(task_run_id=task_run_id, dependencies=dependencies):
708
+ while engine.is_running():
709
+ run_coro_as_sync(engine.wait_until_ready())
710
+ with engine.run_context(), engine.transaction_context() as txn:
711
+ # TODO: generators should default to commit_mode=OFF
712
+ # because they are dynamic by definition
713
+ # for now we just prevent this branch explicitly
714
+ if False and txn.is_committed():
715
+ txn.read()
716
+ else:
717
+ call_args, call_kwargs = parameters_to_args_kwargs(
718
+ task.fn, engine.parameters or {}
719
+ )
720
+ gen = task.fn(*call_args, **call_kwargs)
633
721
  try:
634
- # This is where the task is actually run.
635
- with timeout_async(seconds=run.task.timeout_seconds):
636
- call_args, call_kwargs = parameters_to_args_kwargs(
637
- task.fn, run.parameters or {}
638
- )
639
- run.logger.debug(
640
- f"Executing task {task.name!r} for task run {run.task_run.name!r}..."
641
- )
642
- result_factory = getattr(
643
- TaskRunContext.get(), "result_factory", None
644
- )
645
- with transaction(
646
- key=run.compute_transaction_key(),
647
- store=ResultFactoryStore(result_factory=result_factory),
648
- ) as txn:
649
- if txn.is_committed():
650
- result = txn.read()
651
- else:
652
- result = await task.fn(*call_args, **call_kwargs) # type: ignore
653
-
654
- # If the task run is successful, finalize it.
655
- # do this within the transaction lifecycle
656
- # in order to get the proper result serialization
657
- run.handle_success(result, transaction=txn)
658
-
659
- except TimeoutError as exc:
660
- run.handle_timeout(exc)
661
- except Exception as exc:
662
- run.handle_exception(exc)
663
-
664
- if run.state.is_final():
665
- for hook in run.get_hooks(run.state, as_async=True):
666
- await hook()
667
-
668
- if return_type == "state":
669
- return run.state
670
- return run.result()
722
+ while True:
723
+ gen_result = next(gen)
724
+ # link the current state to the result for dependency tracking
725
+ #
726
+ # TODO: this could grow the task_run_result
727
+ # dictionary in an unbounded way, so finding a
728
+ # way to periodically clean it up (using
729
+ # weakrefs or similar) would be good
730
+ link_state_to_result(engine.state, gen_result)
731
+ yield gen_result
732
+ except StopIteration as exc:
733
+ engine.handle_success(exc.value, transaction=txn)
734
+ except GeneratorExit as exc:
735
+ engine.handle_success(None, transaction=txn)
736
+ gen.throw(exc)
737
+
738
+ return engine.result()
739
+
740
+
741
+ async def run_generator_task_async(
742
+ task: Task[P, R],
743
+ task_run_id: Optional[UUID] = None,
744
+ task_run: Optional[TaskRun] = None,
745
+ parameters: Optional[Dict[str, Any]] = None,
746
+ wait_for: Optional[Iterable[PrefectFuture]] = None,
747
+ return_type: Literal["state", "result"] = "result",
748
+ dependencies: Optional[Dict[str, Set[TaskRunInput]]] = None,
749
+ context: Optional[Dict[str, Any]] = None,
750
+ ) -> AsyncGenerator[R, None]:
751
+ if return_type != "result":
752
+ raise ValueError("The return_type for a generator task must be 'result'")
753
+ engine = TaskRunEngine[P, R](
754
+ task=task,
755
+ parameters=parameters,
756
+ task_run=task_run,
757
+ wait_for=wait_for,
758
+ context=context,
759
+ )
760
+
761
+ with engine.start(task_run_id=task_run_id, dependencies=dependencies):
762
+ while engine.is_running():
763
+ await engine.wait_until_ready()
764
+ with engine.run_context(), engine.transaction_context() as txn:
765
+ # TODO: generators should default to commit_mode=OFF
766
+ # because they are dynamic by definition
767
+ # for now we just prevent this branch explicitly
768
+ if False and txn.is_committed():
769
+ txn.read()
770
+ else:
771
+ call_args, call_kwargs = parameters_to_args_kwargs(
772
+ task.fn, engine.parameters or {}
773
+ )
774
+ gen = task.fn(*call_args, **call_kwargs)
775
+ try:
776
+ while True:
777
+ # can't use anext in Python < 3.10
778
+ gen_result = await gen.__anext__()
779
+ # link the current state to the result for dependency tracking
780
+ #
781
+ # TODO: this could grow the task_run_result
782
+ # dictionary in an unbounded way, so finding a
783
+ # way to periodically clean it up (using
784
+ # weakrefs or similar) would be good
785
+ link_state_to_result(engine.state, gen_result)
786
+ yield gen_result
787
+ except (StopAsyncIteration, GeneratorExit) as exc:
788
+ engine.handle_success(None, transaction=txn)
789
+ if isinstance(exc, GeneratorExit):
790
+ gen.throw(exc)
791
+
792
+ # async generators can't return, but we can raise failures here
793
+ if engine.state.is_failed():
794
+ engine.result()
671
795
 
672
796
 
673
797
  def run_task(
@@ -709,7 +833,11 @@ def run_task(
709
833
  dependencies=dependencies,
710
834
  context=context,
711
835
  )
712
- if task.isasync:
836
+ if task.isasync and task.isgenerator:
837
+ return run_generator_task_async(**kwargs)
838
+ elif task.isgenerator:
839
+ return run_generator_task_sync(**kwargs)
840
+ elif task.isasync:
713
841
  return run_task_async(**kwargs)
714
842
  else:
715
843
  return run_task_sync(**kwargs)