prefect-client 2.17.1__py3-none-any.whl → 2.18.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (71)
  1. prefect/_internal/compatibility/deprecated.py +2 -0
  2. prefect/_internal/pydantic/_compat.py +1 -0
  3. prefect/_internal/pydantic/utilities/field_validator.py +25 -10
  4. prefect/_internal/pydantic/utilities/model_dump.py +1 -1
  5. prefect/_internal/pydantic/utilities/model_validate.py +1 -1
  6. prefect/_internal/pydantic/utilities/model_validator.py +11 -3
  7. prefect/_internal/schemas/fields.py +31 -12
  8. prefect/_internal/schemas/validators.py +0 -6
  9. prefect/_version.py +97 -38
  10. prefect/blocks/abstract.py +34 -1
  11. prefect/blocks/core.py +1 -1
  12. prefect/blocks/notifications.py +16 -7
  13. prefect/blocks/system.py +2 -3
  14. prefect/client/base.py +10 -5
  15. prefect/client/orchestration.py +405 -85
  16. prefect/client/schemas/actions.py +4 -3
  17. prefect/client/schemas/objects.py +6 -5
  18. prefect/client/schemas/schedules.py +2 -6
  19. prefect/client/schemas/sorting.py +9 -0
  20. prefect/client/utilities.py +25 -3
  21. prefect/concurrency/asyncio.py +11 -5
  22. prefect/concurrency/events.py +3 -3
  23. prefect/concurrency/services.py +1 -1
  24. prefect/concurrency/sync.py +9 -5
  25. prefect/deployments/__init__.py +0 -2
  26. prefect/deployments/base.py +2 -144
  27. prefect/deployments/deployments.py +29 -20
  28. prefect/deployments/runner.py +36 -28
  29. prefect/deployments/steps/core.py +3 -3
  30. prefect/deprecated/packaging/serializers.py +5 -4
  31. prefect/engine.py +3 -1
  32. prefect/events/__init__.py +45 -0
  33. prefect/events/actions.py +250 -18
  34. prefect/events/cli/automations.py +201 -0
  35. prefect/events/clients.py +179 -21
  36. prefect/events/filters.py +30 -3
  37. prefect/events/instrument.py +40 -40
  38. prefect/events/related.py +2 -1
  39. prefect/events/schemas/automations.py +126 -8
  40. prefect/events/schemas/deployment_triggers.py +23 -277
  41. prefect/events/schemas/events.py +7 -7
  42. prefect/events/utilities.py +3 -1
  43. prefect/events/worker.py +21 -8
  44. prefect/exceptions.py +1 -1
  45. prefect/flows.py +33 -18
  46. prefect/input/actions.py +9 -9
  47. prefect/input/run_input.py +49 -37
  48. prefect/logging/__init__.py +2 -2
  49. prefect/logging/loggers.py +64 -1
  50. prefect/new_flow_engine.py +293 -0
  51. prefect/new_task_engine.py +374 -0
  52. prefect/results.py +32 -12
  53. prefect/runner/runner.py +3 -2
  54. prefect/serializers.py +62 -31
  55. prefect/server/api/collections_data/views/aggregate-worker-metadata.json +44 -3
  56. prefect/settings.py +32 -10
  57. prefect/states.py +25 -19
  58. prefect/tasks.py +17 -0
  59. prefect/types/__init__.py +90 -0
  60. prefect/utilities/asyncutils.py +37 -0
  61. prefect/utilities/engine.py +6 -4
  62. prefect/utilities/pydantic.py +34 -15
  63. prefect/utilities/schema_tools/hydration.py +88 -19
  64. prefect/utilities/schema_tools/validation.py +1 -1
  65. prefect/variables.py +4 -4
  66. {prefect_client-2.17.1.dist-info → prefect_client-2.18.1.dist-info}/METADATA +1 -1
  67. {prefect_client-2.17.1.dist-info → prefect_client-2.18.1.dist-info}/RECORD +71 -67
  68. /prefect/{concurrency/common.py → events/cli/__init__.py} +0 -0
  69. {prefect_client-2.17.1.dist-info → prefect_client-2.18.1.dist-info}/LICENSE +0 -0
  70. {prefect_client-2.17.1.dist-info → prefect_client-2.18.1.dist-info}/WHEEL +0 -0
  71. {prefect_client-2.17.1.dist-info → prefect_client-2.18.1.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,293 @@
1
+ import asyncio
2
+ from contextlib import asynccontextmanager
3
+ from dataclasses import dataclass
4
+ from typing import (
5
+ Any,
6
+ Coroutine,
7
+ Dict,
8
+ Generic,
9
+ Iterable,
10
+ Literal,
11
+ Optional,
12
+ TypeVar,
13
+ Union,
14
+ cast,
15
+ )
16
+
17
+ import anyio
18
+ from typing_extensions import ParamSpec
19
+
20
+ from prefect import Flow, Task, get_client
21
+ from prefect.client.orchestration import PrefectClient
22
+ from prefect.client.schemas import FlowRun, TaskRun
23
+ from prefect.client.schemas.filters import FlowRunFilter
24
+ from prefect.client.schemas.sorting import FlowRunSort
25
+ from prefect.context import FlowRunContext
26
+ from prefect.futures import PrefectFuture, resolve_futures_to_states
27
+ from prefect.logging.loggers import flow_run_logger
28
+ from prefect.results import ResultFactory
29
+ from prefect.states import (
30
+ Pending,
31
+ Running,
32
+ State,
33
+ exception_to_failed_state,
34
+ return_value_to_state,
35
+ )
36
+ from prefect.utilities.asyncutils import A, Async
37
+ from prefect.utilities.engine import (
38
+ _dynamic_key_for_task_run,
39
+ _resolve_custom_flow_run_name,
40
+ collect_task_run_inputs,
41
+ propose_state,
42
+ )
43
+
44
P = ParamSpec("P")
R = TypeVar("R")


@dataclass
class FlowRunEngine(Generic[P, R]):
    """
    Drive a single flow run through its lifecycle against the Prefect API.

    While inside `start()` the engine holds an open client, creates the flow
    run if one was not supplied, validates parameters, and proposes state
    transitions on the run's behalf.
    """

    flow: Flow[P, Coroutine[Any, Any, R]]
    # `None` is normalized to an empty dict in `__post_init__`
    parameters: Optional[Dict[str, Any]] = None
    # an existing flow run to drive; `start()` creates one when this is `None`
    flow_run: Optional[FlowRun] = None
    _is_started: bool = False
    _client: Optional[PrefectClient] = None
    # set after parameter validation fails; `set_state` then stops proposing
    short_circuit: bool = False

    def __post_init__(self):
        if self.parameters is None:
            self.parameters = {}

    @property
    def client(self) -> PrefectClient:
        """
        Return the client opened by `start()`.

        Raises:
            RuntimeError: if called outside of `start()`.
        """
        if not self._is_started or self._client is None:
            raise RuntimeError("Engine has not started.")
        return self._client

    @property
    def state(self) -> State:
        """The state currently recorded on the local flow run object."""
        return self.flow_run.state  # type: ignore

    async def begin_run(self) -> State:
        """Propose `Running`, re-proposing every second while the answer is pending."""
        new_state = Running()
        state = await self.set_state(new_state)
        while state.is_pending():
            await asyncio.sleep(1)
            state = await self.set_state(new_state)
        return state

    async def set_state(self, state: State) -> State:
        """
        Propose `state` for the flow run and mirror the accepted state onto the
        local flow run object.

        When `short_circuit` is set, nothing is proposed and the current state
        is returned unchanged.
        """
        # prevents any state-setting activity
        if self.short_circuit:
            return self.state

        state = await propose_state(self.client, state, flow_run_id=self.flow_run.id)  # type: ignore
        self.flow_run.state = state  # type: ignore
        self.flow_run.state_name = state.name  # type: ignore
        self.flow_run.state_type = state.type  # type: ignore
        return state

    async def result(self, raise_on_failure: bool = True) -> "Union[R, State, None]":
        """Return the result held by the current state (requested with `fetch=True`)."""
        return await self.state.result(raise_on_failure=raise_on_failure, fetch=True)

    async def handle_success(self, result: R) -> R:
        """
        Convert the flow's return value into a terminal state and propose it.

        Futures inside `result` are resolved to states first; the original
        return value is handed back unchanged.
        """
        result_factory = getattr(FlowRunContext.get(), "result_factory", None)
        terminal_state = await return_value_to_state(
            await resolve_futures_to_states(result),
            result_factory=result_factory,
        )
        await self.set_state(terminal_state)
        return result

    async def handle_exception(
        self,
        exc: Exception,
        msg: Optional[str] = None,
        result_factory: Optional[ResultFactory] = None,
    ) -> State:
        """
        Propose a failed state built from `exc`.

        Args:
            exc: the exception raised by the flow (or by parameter validation)
            msg: message for the failed state; a generic one is used if omitted
            result_factory: factory used to store the exception; defaults to the
                one on the current `FlowRunContext`, if any

        Returns:
            The state accepted by the orchestrator. If that state is scheduled
            (presumably a granted retry), `Running` is proposed immediately and
            that outcome is returned instead.
        """
        context = FlowRunContext.get()
        state = await exception_to_failed_state(
            exc,
            message=msg or "Flow run encountered an exception:",
            result_factory=result_factory or getattr(context, "result_factory", None),
        )
        state = await self.set_state(state)
        if self.state.is_scheduled():
            state = await self.set_state(Running())
        return state

    async def create_subflow_task_run(
        self, client: PrefectClient, context: FlowRunContext
    ) -> TaskRun:
        """
        Create the placeholder task run that represents this subflow inside the
        parent flow run given by `context`. The task run starts `Pending` and
        records the flow's parameters as task inputs.
        """
        dummy_task = Task(
            name=self.flow.name, fn=self.flow.fn, version=self.flow.version
        )
        task_inputs = {
            k: await collect_task_run_inputs(v) for k, v in self.parameters.items()
        }
        parent_task_run = await client.create_task_run(
            task=dummy_task,
            flow_run_id=(
                context.flow_run.id if getattr(context, "flow_run", None) else None
            ),
            dynamic_key=_dynamic_key_for_task_run(context, dummy_task),
            task_inputs=task_inputs,
            state=Pending(),
        )
        return parent_task_run

    async def get_most_recent_flow_run_for_parent_task_run(
        self, client: PrefectClient, parent_task_run: TaskRun
    ) -> "Union[FlowRun, None]":
        """
        Get the most recent flow run associated with the provided parent task run.

        Args:
            client: an orchestration client
            parent_task_run: the parent task run to get the most recent flow run for

        Returns:
            The flow run with the latest expected start time whose parent is
            `parent_task_run`, or `None` if no flow runs are found.
        """
        flow_runs = await client.read_flow_runs(
            flow_run_filter=FlowRunFilter(
                parent_task_run_id={"any_": [parent_task_run.id]}
            ),
            sort=FlowRunSort.EXPECTED_START_TIME_ASC,
        )
        return flow_runs[-1] if flow_runs else None

    @staticmethod
    def _should_reuse_subflow_run(
        context: FlowRunContext, parent_task_run: TaskRun
    ) -> bool:
        """Decide whether an existing flow run for `parent_task_run` is reused."""
        parent_state = parent_task_run.state
        # A placeholder task run that has not finished never short-circuits.
        if parent_state is None or not parent_state.is_final():
            return False
        parent_flow_run = getattr(context, "flow_run", None)
        rerunning = parent_flow_run is not None and parent_flow_run.run_count > 1
        # On a rerun of the parent flow, subflows whose placeholder task run did
        # not complete are executed again instead of being reused.
        return not (rerunning and not parent_state.is_completed())

    async def create_flow_run(self, client: PrefectClient) -> FlowRun:
        """
        Create the flow run for this engine, in `Pending` state.

        When called inside another flow run (a subflow), a placeholder task run
        is created in the parent first. If that placeholder has already reached
        a final state, the most recent flow run linked to it is returned rather
        than creating a new one, unless the parent flow is being rerun and the
        placeholder did not complete.
        """
        flow_run_ctx = FlowRunContext.get()

        parent_task_run = None
        # this is a subflow run
        if flow_run_ctx:
            parent_task_run = await self.create_subflow_task_run(
                client=client, context=flow_run_ctx
            )
            # If the parent task run already finished, return the last flow run
            # associated with the parent task run. This prevents rerunning a
            # finished flow run when the parent task run is rerun.
            if self._should_reuse_subflow_run(flow_run_ctx, parent_task_run):
                most_recent_flow_run = (
                    await self.get_most_recent_flow_run_for_parent_task_run(
                        client=client, parent_task_run=parent_task_run
                    )
                )
                if most_recent_flow_run:
                    return most_recent_flow_run

        try:
            flow_run_name = _resolve_custom_flow_run_name(
                flow=self.flow, parameters=self.parameters
            )
        except TypeError:
            flow_run_name = None

        flow_run = await client.create_flow_run(
            flow=self.flow,
            name=flow_run_name,
            parameters=self.flow.serialize_parameters(self.parameters),
            state=Pending(),
            parent_task_run_id=getattr(parent_task_run, "id", None),
        )
        return flow_run

    @asynccontextmanager
    async def enter_run_context(self, client: Optional[PrefectClient] = None):
        """
        Refresh the flow run from the API and enter a `FlowRunContext` for it.

        While inside, `self.logger` is a logger bound to this flow run.
        """
        if client is None:
            client = self.client

        self.flow_run = await client.read_flow_run(self.flow_run.id)

        with FlowRunContext(
            flow=self.flow,
            log_prints=self.flow.log_prints or False,
            flow_run=self.flow_run,
            parameters=self.parameters,
            client=client,
            # NOTE(review): this task group is created but never entered here —
            # confirm that is intended.
            background_tasks=anyio.create_task_group(),
            result_factory=await ResultFactory.from_flow(self.flow),
            task_runner=self.flow.task_runner,
        ):
            self.logger = flow_run_logger(flow_run=self.flow_run, flow=self.flow)
            yield

    @asynccontextmanager
    async def start(self):
        """
        Enters a client context and creates a flow run if needed.

        Parameters are validated before the run context is entered. On
        validation failure the run is marked failed and `short_circuit` is set
        so no further states are proposed. The engine is always marked as
        stopped on exit, even if the body raises.
        """
        async with get_client() as client:
            self._client = client
            self._is_started = True
            try:
                if not self.flow_run:
                    self.flow_run = await self.create_flow_run(client)

                # validate prior to context so that context receives validated params
                if self.flow.should_validate_parameters:
                    try:
                        self.parameters = self.flow.validate_parameters(
                            self.parameters
                        )
                    except Exception as exc:
                        await self.handle_exception(
                            exc,
                            msg="Validation of flow parameters failed with error",
                            result_factory=await ResultFactory.from_flow(self.flow),
                        )
                        self.short_circuit = True

                yield self
            finally:
                self._is_started = False
                self._client = None

    def is_running(self) -> bool:
        """Whether a flow run exists and its current state is running."""
        state = getattr(self.flow_run, "state", None)
        return state is not None and state.is_running()

    def is_pending(self) -> bool:
        """Whether a flow run exists and its current state is pending."""
        # TODO: handle a missing flow run differently?
        state = getattr(self.flow_run, "state", None)
        return state is not None and state.is_pending()
256
+
257
+
258
async def run_flow(
    flow: Flow[P, Coroutine[Any, Any, R]],
    flow_run: Optional[FlowRun] = None,
    parameters: Optional[Dict[str, Any]] = None,
    wait_for: Optional[Iterable[PrefectFuture[A, Async]]] = None,
    return_type: Literal["state", "result"] = "result",
) -> "Union[R, State, None]":
    """
    Runs a flow against the API.

    Args:
        flow: the flow to run
        flow_run: an existing flow run to drive; one is created when omitted
        parameters: keyword arguments for the flow function (validated by the
            engine when the flow requests it)
        wait_for: accepted for signature parity; not used yet
        return_type: `"result"` returns the flow's result (raising if the run
            failed); `"state"` returns the final state instead

    We will most likely want to use this logic as a wrapper and return a coroutine for type inference.
    """

    engine = FlowRunEngine[P, R](flow, parameters, flow_run)
    async with engine.start() as run:
        # This is a context manager that keeps track of the state of the flow run.
        await run.begin_run()

        while run.is_running():
            async with run.enter_run_context():
                try:
                    # This is where the flow is actually run.
                    if flow.isasync:
                        result = cast(R, await flow.fn(**(run.parameters or {})))  # type: ignore
                    else:
                        result = cast(R, flow.fn(**(run.parameters or {})))  # type: ignore
                    # If the flow run is successful, finalize it.
                    await run.handle_success(result)

                except Exception as exc:
                    # If the flow fails, and we have retries left, set the flow to retrying.
                    await run.handle_exception(exc)

        if return_type == "state":
            return run.state
        return await run.result()
@@ -0,0 +1,374 @@
1
+ import asyncio
2
+ import logging
3
+ from contextlib import asynccontextmanager
4
+ from dataclasses import dataclass, field
5
+ from typing import (
6
+ Any,
7
+ AsyncGenerator,
8
+ Callable,
9
+ Coroutine,
10
+ Dict,
11
+ Generic,
12
+ Iterable,
13
+ Literal,
14
+ Optional,
15
+ TypeVar,
16
+ Union,
17
+ cast,
18
+ )
19
+ from uuid import uuid4
20
+
21
+ import pendulum
22
+ from typing_extensions import ParamSpec
23
+
24
+ from prefect import Task, get_client
25
+ from prefect.client.orchestration import PrefectClient
26
+ from prefect.client.schemas import TaskRun
27
+ from prefect.client.schemas.objects import TaskRunResult
28
+ from prefect.context import FlowRunContext, TaskRunContext
29
+ from prefect.futures import PrefectFuture, resolve_futures_to_states
30
+ from prefect.logging.loggers import get_logger, task_run_logger
31
+ from prefect.results import ResultFactory
32
+ from prefect.server.schemas.states import State
33
+ from prefect.settings import PREFECT_TASKS_REFRESH_CACHE
34
+ from prefect.states import (
35
+ Pending,
36
+ Retrying,
37
+ Running,
38
+ StateDetails,
39
+ exception_to_crashed_state,
40
+ exception_to_failed_state,
41
+ return_value_to_state,
42
+ )
43
+ from prefect.utilities.asyncutils import A, Async, is_async_fn
44
+ from prefect.utilities.engine import (
45
+ _dynamic_key_for_task_run,
46
+ _get_hook_name,
47
+ _resolve_custom_task_run_name,
48
+ collect_task_run_inputs,
49
+ propose_state,
50
+ )
51
+
52
+
53
@asynccontextmanager
async def timeout(
    delay: Optional[float], *, loop: Optional[asyncio.AbstractEventLoop] = None
) -> AsyncGenerator[None, None]:
    """
    Bound the time spent inside the `async with` body.

    After `delay` seconds the current task is cancelled. The resulting
    `CancelledError` is converted to `TimeoutError`, so callers can handle a
    timeout as an ordinary exception. Cancellations that did not come from this
    timer propagate unchanged. A `delay` of `None` disables the timeout.

    Only awaits inside the body can be interrupted: synchronous code that never
    yields to the event loop runs to completion before the timer callback runs.

    Raises:
        TimeoutError: if the body is still running after `delay` seconds.
    """
    loop = loop or asyncio.get_running_loop()
    task = asyncio.current_task(loop=loop)
    timer_handle: Optional[asyncio.TimerHandle] = None
    timed_out = False

    def _on_timeout() -> None:
        # remember that the cancellation below originates from this timer
        nonlocal timed_out
        timed_out = True
        task.cancel()  # type: ignore[union-attr]

    if delay is not None and task is not None:
        timer_handle = loop.call_later(delay, _on_timeout)

    try:
        yield
    except asyncio.CancelledError:
        if not timed_out:
            raise
        # Withdraw our own cancellation request (Python 3.11+) so it does not
        # leak into enclosing cancellation scopes.
        uncancel = getattr(task, "uncancel", None)
        if uncancel is not None:
            uncancel()
        raise TimeoutError(f"Timed out after {delay} second(s)") from None
    finally:
        if timer_handle is not None:
            timer_handle.cancel()
69
+
70
+
71
P = ParamSpec("P")
R = TypeVar("R")


@dataclass
class TaskRunEngine(Generic[P, R]):
    """
    Drive a single task run through its lifecycle against the Prefect API.

    While inside `start()` the engine holds an open client and creates the task
    run if one was not supplied. It proposes state transitions, runs
    completion/failure hooks when a final state is reached, and counts the
    retries used.

    NOTE(review): `State` in this module is imported from
    `prefect.server.schemas.states`; the client-side state model is probably
    what these annotations intend — confirm.
    """

    task: Task[P, Coroutine[Any, Any, R]]
    logger: logging.Logger = field(default_factory=lambda: get_logger("engine"))
    # `None` is normalized to an empty dict in `__post_init__`
    parameters: Optional[Dict[str, Any]] = None
    # an existing task run to drive; `start()` creates one when this is `None`
    task_run: Optional[TaskRun] = None
    # number of retries already attempted
    retries: int = 0
    _is_started: bool = False
    _client: Optional[PrefectClient] = None

    def __post_init__(self):
        if self.parameters is None:
            self.parameters = {}

    @property
    def client(self) -> PrefectClient:
        """
        Return the client opened by `start()`.

        Raises:
            RuntimeError: if called outside of `start()`.
        """
        if not self._is_started or self._client is None:
            raise RuntimeError("Engine has not started.")
        return self._client

    @property
    def state(self) -> State:
        """The state currently recorded on the local task run object."""
        return self.task_run.state  # type: ignore

    @property
    def can_retry(self) -> bool:
        """Whether `task.retry_condition_fn` (if any) accepts the current state."""
        retry_condition: Optional[  # type: ignore
            Callable[[Task[P, Coroutine[Any, Any, R]], TaskRun, State], bool]
        ] = self.task.retry_condition_fn  # type: ignore
        return not retry_condition or retry_condition(
            self.task, self.task_run, self.state
        )  # type: ignore

    async def _run_hooks(self, state: State) -> None:
        """Run the on_failure and on_completion hooks for a task, making sure to
        catch and log any errors that occur.
        """
        task = self.task
        task_run = self.task_run

        hooks = None
        if state.is_failed() and task.on_failure:
            hooks = task.on_failure
        elif state.is_completed() and task.on_completion:
            hooks = task.on_completion

        if hooks:
            for hook in hooks:
                hook_name = _get_hook_name(hook)
                try:
                    self.logger.info(
                        f"Running hook {hook_name!r} in response to entering state"
                        f" {state.name!r}"
                    )
                    if is_async_fn(hook):
                        await hook(task=task, task_run=task_run, state=state)
                    else:
                        hook(task=task, task_run=task_run, state=state)
                except Exception:
                    self.logger.error(
                        f"An error was encountered while running hook {hook_name!r}",
                        exc_info=True,
                    )
                else:
                    self.logger.info(
                        f"Hook {hook_name!r} finished running successfully"
                    )

    def _compute_state_details(
        self, include_cache_expiration: bool = False
    ) -> StateDetails:
        """
        Build `StateDetails` with the task's cache key (from `cache_key_fn`,
        if set) and its refresh-cache flag. When `include_cache_expiration` is
        true, it also sets a cache expiration of now + `task.cache_expiration`.
        """
        ## setup cache metadata
        task_run_context = TaskRunContext.get()
        cache_key = (
            self.task.cache_key_fn(
                task_run_context,
                self.parameters,
            )
            if self.task.cache_key_fn
            else None
        )
        # Ignore the cached results for a cache key, default = false
        # Setting on task level overrules the Prefect setting (env var)
        refresh_cache = (
            self.task.refresh_cache
            if self.task.refresh_cache is not None
            else PREFECT_TASKS_REFRESH_CACHE.value()
        )

        if include_cache_expiration:
            cache_expiration = (
                (pendulum.now("utc") + self.task.cache_expiration)
                if self.task.cache_expiration
                else None
            )
        else:
            cache_expiration = None
        return StateDetails(
            cache_key=cache_key,
            refresh_cache=refresh_cache,
            cache_expiration=cache_expiration,
        )

    async def begin_run(self) -> State:
        """
        Propose `Running` (with cache metadata), re-proposing every second
        while the answer is pending, and return the accepted state.
        """
        state_details = self._compute_state_details()
        new_state = Running(state_details=state_details)
        state = await self.set_state(new_state)
        while state.is_pending():
            await asyncio.sleep(1)
            state = await self.set_state(new_state)
        return state

    async def set_state(self, state: State, force: bool = False) -> State:
        """
        Propose `state` for the task run, record the accepted state locally,
        and run hooks if the accepted state is final.
        """
        new_state = await propose_state(
            self.client, state, task_run_id=self.task_run.id, force=force
        )  # type: ignore

        # currently this is a hack to keep a reference to the state object
        # that has an in-memory result attached to it; using the API state
        # could result in losing that reference
        self.task_run.state = new_state
        # keep the denormalized fields in sync, as the flow engine does
        self.task_run.state_name = new_state.name
        self.task_run.state_type = new_state.type
        if new_state.is_final():
            await self._run_hooks(new_state)
        return new_state

    async def result(self, raise_on_failure: bool = True) -> "Union[R, State, None]":
        """Return the result held by the current state."""
        return await self.state.result(raise_on_failure=raise_on_failure)

    async def handle_success(self, result: R) -> R:
        """
        Convert the task's return value into a terminal state carrying cache
        metadata and propose it. The original return value is handed back.
        """
        result_factory = getattr(TaskRunContext.get(), "result_factory", None)
        terminal_state = await return_value_to_state(
            await resolve_futures_to_states(result),
            result_factory=result_factory,
        )
        terminal_state.state_details = self._compute_state_details(
            include_cache_expiration=True
        )
        await self.set_state(terminal_state)
        return result

    async def handle_retry(self, exc: Exception) -> bool:
        """
        If the task has retries left, and the retry condition is met, set the task to retrying.
        - If the task has no retries left, or the retry condition is not met, return False.
        - If the task has retries left, and the retry condition is met, return True.

        The retry condition is evaluated against a failed state that carries
        `exc`, because the task run's own state is still `Running` at this point.
        """
        if self.retries >= self.task.retries:
            return False

        retry_condition = self.task.retry_condition_fn
        if retry_condition is not None:
            # no result factory: the exception is attached in memory only
            failed_state = await exception_to_failed_state(
                exc, message="Task run encountered an exception"
            )
            if not retry_condition(self.task, self.task_run, failed_state):
                return False

        await self.set_state(Retrying(), force=True)
        self.retries = self.retries + 1
        return True

    async def handle_exception(self, exc: Exception) -> None:
        """Retry the task if allowed; otherwise propose a failed state for `exc`."""
        # If the task fails, and we have retries left, set the task to retrying.
        if not await self.handle_retry(exc):
            # If the task has no retries left, or the retry condition is not met, set the task to failed.
            context = TaskRunContext.get()
            state = await exception_to_failed_state(
                exc,
                message="Task run encountered an exception",
                result_factory=getattr(context, "result_factory", None),
            )
            await self.set_state(state)

    async def handle_crash(self, exc: BaseException) -> None:
        """Log and force a crashed state built from `exc`."""
        state = await exception_to_crashed_state(exc)
        self.logger.error(f"Crash detected! {state.message}")
        self.logger.debug("Crash details:", exc_info=exc)
        await self.set_state(state, force=True)

    async def create_task_run(self, client: PrefectClient) -> TaskRun:
        """
        Create a `Pending` task run for `self.task`.

        Inside a flow run the task run is attached to it and keyed with
        `_dynamic_key_for_task_run`; otherwise a random hex key is used. When
        called inside another task run, that run is recorded as a `wait_for`
        input.
        """
        flow_run_ctx = FlowRunContext.get()
        try:
            task_run_name = _resolve_custom_task_run_name(self.task, self.parameters)
        except TypeError:
            task_run_name = None

        # prep input tracking
        task_inputs = {
            k: await collect_task_run_inputs(v) for k, v in self.parameters.items()
        }

        # anticipate nested runs
        task_run_ctx = TaskRunContext.get()
        if task_run_ctx:
            task_inputs["wait_for"] = [TaskRunResult(id=task_run_ctx.task_run.id)]

        # TODO: implement wait_for
        # if wait_for:
        #     task_inputs["wait_for"] = await collect_task_run_inputs(wait_for)

        if flow_run_ctx:
            dynamic_key = _dynamic_key_for_task_run(
                context=flow_run_ctx, task=self.task
            )
        else:
            dynamic_key = uuid4().hex
        task_run = await client.create_task_run(
            task=self.task,
            name=task_run_name,
            flow_run_id=(
                getattr(flow_run_ctx.flow_run, "id", None)
                if flow_run_ctx and flow_run_ctx.flow_run
                else None
            ),
            dynamic_key=dynamic_key,
            state=Pending(),
            task_inputs=task_inputs,
        )
        return task_run

    @asynccontextmanager
    async def enter_run_context(self, client: Optional[PrefectClient] = None):
        """
        Refresh the task run from the API and enter a `TaskRunContext` for it.

        While inside, `self.logger` is a logger bound to this task run.
        """
        if client is None:
            client = self.client

        self.task_run = await client.read_task_run(self.task_run.id)

        with TaskRunContext(
            task=self.task,
            log_prints=self.task.log_prints or False,
            task_run=self.task_run,
            parameters=self.parameters,
            result_factory=await ResultFactory.from_autonomous_task(self.task),
            client=client,
        ):
            self.logger = task_run_logger(task_run=self.task_run, task=self.task)
            yield

    @asynccontextmanager
    async def start(self):
        """
        Enters a client context and creates a task run if needed.

        Regular exceptions propagate untouched. Other `BaseException`s are
        recorded as a crash on the task run (when one exists) and re-raised.
        """
        async with get_client() as client:
            self._client = client
            self._is_started = True
            try:
                if not self.task_run:
                    self.task_run = await self.create_task_run(client)

                yield self
            except Exception:
                # regular exceptions are caught and re-raised to the user
                raise
            except BaseException as exc:
                # BaseExceptions are caught and handled as crashes; if the task
                # run was never created there is nothing to record it on
                if self.task_run is not None:
                    await self.handle_crash(exc)
                raise
            finally:
                self._is_started = False
                self._client = None

    def is_running(self) -> bool:
        """Whether a task run exists and its current state is running."""
        state = getattr(self.task_run, "state", None)
        return state is not None and state.is_running()

    def is_pending(self) -> bool:
        """Whether a task run exists and its current state is pending."""
        # TODO: handle a missing task run differently?
        state = getattr(self.task_run, "state", None)
        return state is not None and state.is_pending()
336
+
337
+
338
async def run_task(
    task: Task[P, Coroutine[Any, Any, R]],
    task_run: Optional[TaskRun] = None,
    parameters: Optional[Dict[str, Any]] = None,
    wait_for: Optional[Iterable[PrefectFuture[A, Async]]] = None,
    return_type: Literal["state", "result"] = "result",
) -> "Union[R, State, None]":
    """
    Execute `task` against the API, retrying through the engine as configured.

    Args:
        task: the task to run
        task_run: an existing task run to drive; one is created when omitted
        parameters: keyword arguments passed to the task function
        wait_for: accepted for signature parity; not used yet
        return_type: `"result"` returns the task's return value (or raises if
            the run ends failed); `"state"` returns the final state

    This is likely to become a wrapper returning a coroutine so that return
    types can be inferred.
    """
    call_kwargs = parameters or {}
    engine = TaskRunEngine[P, R](task=task, parameters=parameters, task_run=task_run)

    async with engine.start() as run:
        # the engine context tracks the lifecycle of the task run
        await run.begin_run()

        while run.is_running():
            async with run.enter_run_context():
                try:
                    # invoke the user's function, bounded by the task timeout
                    async with timeout(run.task.timeout_seconds):
                        if task.isasync:
                            outcome = cast(R, await task.fn(**call_kwargs))  # type: ignore
                        else:
                            outcome = cast(R, task.fn(**call_kwargs))  # type: ignore
                    # record success before handing anything back
                    await run.handle_success(outcome)
                    if return_type == "result":
                        return outcome
                except Exception as exc:
                    await run.handle_exception(exc)

        return run.state if return_type == "state" else await run.result()