prefect-client 3.0.0rc12__py3-none-any.whl → 3.0.0rc14__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- prefect/blocks/core.py +132 -4
- prefect/blocks/notifications.py +26 -3
- prefect/client/base.py +30 -24
- prefect/client/orchestration.py +121 -47
- prefect/client/utilities.py +4 -4
- prefect/concurrency/asyncio.py +48 -7
- prefect/concurrency/context.py +24 -0
- prefect/concurrency/services.py +24 -8
- prefect/concurrency/sync.py +30 -3
- prefect/context.py +83 -23
- prefect/events/clients.py +59 -4
- prefect/events/worker.py +9 -2
- prefect/flow_engine.py +6 -3
- prefect/flows.py +166 -8
- prefect/futures.py +84 -2
- prefect/profiles.toml +13 -2
- prefect/runner/runner.py +6 -1
- prefect/settings.py +35 -7
- prefect/task_engine.py +870 -291
- prefect/task_runs.py +24 -1
- prefect/task_worker.py +27 -16
- prefect/utilities/callables.py +5 -3
- prefect/utilities/importtools.py +138 -58
- prefect/utilities/schema_tools/validation.py +30 -0
- prefect/utilities/services.py +32 -0
- {prefect_client-3.0.0rc12.dist-info → prefect_client-3.0.0rc14.dist-info}/METADATA +2 -1
- {prefect_client-3.0.0rc12.dist-info → prefect_client-3.0.0rc14.dist-info}/RECORD +30 -29
- {prefect_client-3.0.0rc12.dist-info → prefect_client-3.0.0rc14.dist-info}/LICENSE +0 -0
- {prefect_client-3.0.0rc12.dist-info → prefect_client-3.0.0rc14.dist-info}/WHEEL +0 -0
- {prefect_client-3.0.0rc12.dist-info → prefect_client-3.0.0rc14.dist-info}/top_level.txt +0 -0
prefect/task_engine.py
CHANGED
@@ -3,7 +3,7 @@ import logging
|
|
3
3
|
import threading
|
4
4
|
import time
|
5
5
|
from asyncio import CancelledError
|
6
|
-
from contextlib import ExitStack, contextmanager
|
6
|
+
from contextlib import ExitStack, asynccontextmanager, contextmanager
|
7
7
|
from dataclasses import dataclass, field
|
8
8
|
from functools import wraps
|
9
9
|
from textwrap import dedent
|
@@ -31,12 +31,16 @@ import pendulum
|
|
31
31
|
from typing_extensions import ParamSpec
|
32
32
|
|
33
33
|
from prefect import Task
|
34
|
-
from prefect.client.orchestration import SyncPrefectClient
|
34
|
+
from prefect.client.orchestration import PrefectClient, SyncPrefectClient, get_client
|
35
35
|
from prefect.client.schemas import TaskRun
|
36
36
|
from prefect.client.schemas.objects import State, TaskRunInput
|
37
|
+
from prefect.concurrency.asyncio import concurrency as aconcurrency
|
38
|
+
from prefect.concurrency.context import ConcurrencyContext
|
39
|
+
from prefect.concurrency.sync import concurrency
|
37
40
|
from prefect.context import (
|
38
|
-
|
41
|
+
AsyncClientContext,
|
39
42
|
FlowRunContext,
|
43
|
+
SyncClientContext,
|
40
44
|
TaskRunContext,
|
41
45
|
hydrated_context,
|
42
46
|
)
|
@@ -59,6 +63,7 @@ from prefect.settings import (
|
|
59
63
|
)
|
60
64
|
from prefect.states import (
|
61
65
|
AwaitingRetry,
|
66
|
+
Completed,
|
62
67
|
Failed,
|
63
68
|
Paused,
|
64
69
|
Pending,
|
@@ -77,6 +82,7 @@ from prefect.utilities.engine import (
|
|
77
82
|
_get_hook_name,
|
78
83
|
emit_task_run_state_change_event,
|
79
84
|
link_state_to_result,
|
85
|
+
propose_state,
|
80
86
|
propose_state_sync,
|
81
87
|
resolve_to_final_result,
|
82
88
|
)
|
@@ -86,47 +92,745 @@ from prefect.utilities.timeout import timeout, timeout_async
|
|
86
92
|
P = ParamSpec("P")
|
87
93
|
R = TypeVar("R")
|
88
94
|
|
95
|
+
BACKOFF_MAX = 10
|
96
|
+
|
89
97
|
|
90
98
|
class TaskRunTimeoutError(TimeoutError):
|
91
99
|
"""Raised when a task run exceeds its timeout."""
|
92
100
|
|
93
101
|
|
94
|
-
@dataclass
|
95
|
-
class
|
96
|
-
task: Union[Task[P, R], Task[P, Coroutine[Any, Any, R]]]
|
97
|
-
logger: logging.Logger = field(default_factory=lambda: get_logger("engine"))
|
98
|
-
parameters: Optional[Dict[str, Any]] = None
|
99
|
-
task_run: Optional[TaskRun] = None
|
100
|
-
retries: int = 0
|
101
|
-
wait_for: Optional[Iterable[PrefectFuture]] = None
|
102
|
-
context: Optional[Dict[str, Any]] = None
|
103
|
-
# holds the return value from the user code
|
104
|
-
_return_value: Union[R, Type[NotSet]] = NotSet
|
105
|
-
# holds the exception raised by the user code, if any
|
106
|
-
_raised: Union[Exception, Type[NotSet]] = NotSet
|
107
|
-
_initial_run_context: Optional[TaskRunContext] = None
|
108
|
-
_is_started: bool = False
|
109
|
-
|
110
|
-
|
111
|
-
|
102
|
+
@dataclass
|
103
|
+
class BaseTaskRunEngine(Generic[P, R]):
|
104
|
+
task: Union[Task[P, R], Task[P, Coroutine[Any, Any, R]]]
|
105
|
+
logger: logging.Logger = field(default_factory=lambda: get_logger("engine"))
|
106
|
+
parameters: Optional[Dict[str, Any]] = None
|
107
|
+
task_run: Optional[TaskRun] = None
|
108
|
+
retries: int = 0
|
109
|
+
wait_for: Optional[Iterable[PrefectFuture]] = None
|
110
|
+
context: Optional[Dict[str, Any]] = None
|
111
|
+
# holds the return value from the user code
|
112
|
+
_return_value: Union[R, Type[NotSet]] = NotSet
|
113
|
+
# holds the exception raised by the user code, if any
|
114
|
+
_raised: Union[Exception, Type[NotSet]] = NotSet
|
115
|
+
_initial_run_context: Optional[TaskRunContext] = None
|
116
|
+
_is_started: bool = False
|
117
|
+
_task_name_set: bool = False
|
118
|
+
_last_event: Optional[PrefectEvent] = None
|
119
|
+
|
120
|
+
def __post_init__(self):
|
121
|
+
if self.parameters is None:
|
122
|
+
self.parameters = {}
|
123
|
+
|
124
|
+
@property
|
125
|
+
def state(self) -> State:
|
126
|
+
if not self.task_run:
|
127
|
+
raise ValueError("Task run is not set")
|
128
|
+
return self.task_run.state
|
129
|
+
|
130
|
+
def is_cancelled(self) -> bool:
|
131
|
+
if (
|
132
|
+
self.context
|
133
|
+
and "cancel_event" in self.context
|
134
|
+
and isinstance(self.context["cancel_event"], threading.Event)
|
135
|
+
):
|
136
|
+
return self.context["cancel_event"].is_set()
|
137
|
+
return False
|
138
|
+
|
139
|
+
def compute_transaction_key(self) -> Optional[str]:
|
140
|
+
key = None
|
141
|
+
if self.task.cache_policy:
|
142
|
+
flow_run_context = FlowRunContext.get()
|
143
|
+
task_run_context = TaskRunContext.get()
|
144
|
+
|
145
|
+
if flow_run_context:
|
146
|
+
parameters = flow_run_context.parameters
|
147
|
+
else:
|
148
|
+
parameters = None
|
149
|
+
|
150
|
+
key = self.task.cache_policy.compute_key(
|
151
|
+
task_ctx=task_run_context,
|
152
|
+
inputs=self.parameters,
|
153
|
+
flow_parameters=parameters,
|
154
|
+
)
|
155
|
+
elif self.task.result_storage_key is not None:
|
156
|
+
key = _format_user_supplied_storage_key(self.task.result_storage_key)
|
157
|
+
return key
|
158
|
+
|
159
|
+
def _resolve_parameters(self):
|
160
|
+
if not self.parameters:
|
161
|
+
return {}
|
162
|
+
|
163
|
+
resolved_parameters = {}
|
164
|
+
for parameter, value in self.parameters.items():
|
165
|
+
try:
|
166
|
+
resolved_parameters[parameter] = visit_collection(
|
167
|
+
value,
|
168
|
+
visit_fn=resolve_to_final_result,
|
169
|
+
return_data=True,
|
170
|
+
max_depth=-1,
|
171
|
+
remove_annotations=True,
|
172
|
+
context={},
|
173
|
+
)
|
174
|
+
except UpstreamTaskError:
|
175
|
+
raise
|
176
|
+
except Exception as exc:
|
177
|
+
raise PrefectException(
|
178
|
+
f"Failed to resolve inputs in parameter {parameter!r}. If your"
|
179
|
+
" parameter type is not supported, consider using the `quote`"
|
180
|
+
" annotation to skip resolution of inputs."
|
181
|
+
) from exc
|
182
|
+
|
183
|
+
self.parameters = resolved_parameters
|
184
|
+
|
185
|
+
def _wait_for_dependencies(self):
|
186
|
+
if not self.wait_for:
|
187
|
+
return
|
188
|
+
|
189
|
+
visit_collection(
|
190
|
+
self.wait_for,
|
191
|
+
visit_fn=resolve_to_final_result,
|
192
|
+
return_data=False,
|
193
|
+
max_depth=-1,
|
194
|
+
remove_annotations=True,
|
195
|
+
context={"current_task_run": self.task_run, "current_task": self.task},
|
196
|
+
)
|
197
|
+
|
198
|
+
def record_terminal_state_timing(self, state: State) -> None:
|
199
|
+
if PREFECT_EXPERIMENTAL_ENABLE_CLIENT_SIDE_TASK_ORCHESTRATION:
|
200
|
+
if self.task_run.start_time and not self.task_run.end_time:
|
201
|
+
self.task_run.end_time = state.timestamp
|
202
|
+
|
203
|
+
if self.task_run.state.is_running():
|
204
|
+
self.task_run.total_run_time += (
|
205
|
+
state.timestamp - self.task_run.state.timestamp
|
206
|
+
)
|
207
|
+
|
208
|
+
def is_running(self) -> bool:
|
209
|
+
"""Whether or not the engine is currently running a task."""
|
210
|
+
if (task_run := getattr(self, "task_run", None)) is None:
|
211
|
+
return False
|
212
|
+
return task_run.state.is_running() or task_run.state.is_scheduled()
|
213
|
+
|
214
|
+
def log_finished_message(self):
|
215
|
+
# If debugging, use the more complete `repr` than the usual `str` description
|
216
|
+
display_state = repr(self.state) if PREFECT_DEBUG_MODE else str(self.state)
|
217
|
+
level = logging.INFO if self.state.is_completed() else logging.ERROR
|
218
|
+
msg = f"Finished in state {display_state}"
|
219
|
+
if self.state.is_pending():
|
220
|
+
msg += (
|
221
|
+
"\nPlease wait for all submitted tasks to complete"
|
222
|
+
" before exiting your flow by calling `.wait()` on the "
|
223
|
+
"`PrefectFuture` returned from your `.submit()` calls."
|
224
|
+
)
|
225
|
+
msg += dedent(
|
226
|
+
"""
|
227
|
+
|
228
|
+
Example:
|
229
|
+
|
230
|
+
from prefect import flow, task
|
231
|
+
|
232
|
+
@task
|
233
|
+
def say_hello(name):
|
234
|
+
print f"Hello, {name}!"
|
235
|
+
|
236
|
+
@flow
|
237
|
+
def example_flow():
|
238
|
+
future = say_hello.submit(name="Marvin)
|
239
|
+
future.wait()
|
240
|
+
|
241
|
+
example_flow()
|
242
|
+
"""
|
243
|
+
)
|
244
|
+
self.logger.log(
|
245
|
+
level=level,
|
246
|
+
msg=msg,
|
247
|
+
)
|
248
|
+
|
249
|
+
def handle_rollback(self, txn: Transaction) -> None:
|
250
|
+
assert self.task_run is not None
|
251
|
+
|
252
|
+
rolled_back_state = Completed(
|
253
|
+
name="RolledBack",
|
254
|
+
message="Task rolled back as part of transaction",
|
255
|
+
)
|
256
|
+
|
257
|
+
self._last_event = emit_task_run_state_change_event(
|
258
|
+
task_run=self.task_run,
|
259
|
+
initial_state=self.state,
|
260
|
+
validated_state=rolled_back_state,
|
261
|
+
follows=self._last_event,
|
262
|
+
)
|
263
|
+
|
264
|
+
|
265
|
+
@dataclass
|
266
|
+
class SyncTaskRunEngine(BaseTaskRunEngine[P, R]):
|
267
|
+
_client: Optional[SyncPrefectClient] = None
|
268
|
+
|
269
|
+
@property
|
270
|
+
def client(self) -> SyncPrefectClient:
|
271
|
+
if not self._is_started or self._client is None:
|
272
|
+
raise RuntimeError("Engine has not started.")
|
273
|
+
return self._client
|
274
|
+
|
275
|
+
def can_retry(self, exc: Exception) -> bool:
|
276
|
+
retry_condition: Optional[
|
277
|
+
Callable[[Task[P, Coroutine[Any, Any, R]], TaskRun, State], bool]
|
278
|
+
] = self.task.retry_condition_fn
|
279
|
+
if not self.task_run:
|
280
|
+
raise ValueError("Task run is not set")
|
281
|
+
try:
|
282
|
+
self.logger.debug(
|
283
|
+
f"Running `retry_condition_fn` check {retry_condition!r} for task"
|
284
|
+
f" {self.task.name!r}"
|
285
|
+
)
|
286
|
+
state = Failed(
|
287
|
+
data=exc,
|
288
|
+
message=f"Task run encountered unexpected exception: {repr(exc)}",
|
289
|
+
)
|
290
|
+
if inspect.iscoroutinefunction(retry_condition):
|
291
|
+
should_retry = run_coro_as_sync(
|
292
|
+
retry_condition(self.task, self.task_run, state)
|
293
|
+
)
|
294
|
+
elif inspect.isfunction(retry_condition):
|
295
|
+
should_retry = retry_condition(self.task, self.task_run, state)
|
296
|
+
else:
|
297
|
+
should_retry = not retry_condition
|
298
|
+
return should_retry
|
299
|
+
except Exception:
|
300
|
+
self.logger.error(
|
301
|
+
(
|
302
|
+
"An error was encountered while running `retry_condition_fn` check"
|
303
|
+
f" '{retry_condition!r}' for task {self.task.name!r}"
|
304
|
+
),
|
305
|
+
exc_info=True,
|
306
|
+
)
|
307
|
+
return False
|
308
|
+
|
309
|
+
def call_hooks(self, state: Optional[State] = None):
|
310
|
+
if state is None:
|
311
|
+
state = self.state
|
312
|
+
task = self.task
|
313
|
+
task_run = self.task_run
|
314
|
+
|
315
|
+
if not task_run:
|
316
|
+
raise ValueError("Task run is not set")
|
317
|
+
|
318
|
+
if state.is_failed() and task.on_failure_hooks:
|
319
|
+
hooks = task.on_failure_hooks
|
320
|
+
elif state.is_completed() and task.on_completion_hooks:
|
321
|
+
hooks = task.on_completion_hooks
|
322
|
+
else:
|
323
|
+
hooks = None
|
324
|
+
|
325
|
+
for hook in hooks or []:
|
326
|
+
hook_name = _get_hook_name(hook)
|
327
|
+
|
328
|
+
try:
|
329
|
+
self.logger.info(
|
330
|
+
f"Running hook {hook_name!r} in response to entering state"
|
331
|
+
f" {state.name!r}"
|
332
|
+
)
|
333
|
+
result = hook(task, task_run, state)
|
334
|
+
if inspect.isawaitable(result):
|
335
|
+
run_coro_as_sync(result)
|
336
|
+
except Exception:
|
337
|
+
self.logger.error(
|
338
|
+
f"An error was encountered while running hook {hook_name!r}",
|
339
|
+
exc_info=True,
|
340
|
+
)
|
341
|
+
else:
|
342
|
+
self.logger.info(f"Hook {hook_name!r} finished running successfully")
|
343
|
+
|
344
|
+
def begin_run(self):
|
345
|
+
try:
|
346
|
+
self._resolve_parameters()
|
347
|
+
self._wait_for_dependencies()
|
348
|
+
except UpstreamTaskError as upstream_exc:
|
349
|
+
state = self.set_state(
|
350
|
+
Pending(
|
351
|
+
name="NotReady",
|
352
|
+
message=str(upstream_exc),
|
353
|
+
),
|
354
|
+
# if orchestrating a run already in a pending state, force orchestration to
|
355
|
+
# update the state name
|
356
|
+
force=self.state.is_pending(),
|
357
|
+
)
|
358
|
+
return
|
359
|
+
|
360
|
+
new_state = Running()
|
361
|
+
|
362
|
+
if PREFECT_EXPERIMENTAL_ENABLE_CLIENT_SIDE_TASK_ORCHESTRATION:
|
363
|
+
self.task_run.start_time = new_state.timestamp
|
364
|
+
self.task_run.run_count += 1
|
365
|
+
|
366
|
+
flow_run_context = FlowRunContext.get()
|
367
|
+
if flow_run_context:
|
368
|
+
# Carry forward any task run information from the flow run
|
369
|
+
flow_run = flow_run_context.flow_run
|
370
|
+
self.task_run.flow_run_run_count = flow_run.run_count
|
371
|
+
|
372
|
+
state = self.set_state(new_state)
|
373
|
+
|
374
|
+
# TODO: this is temporary until the API stops rejecting state transitions
|
375
|
+
# and the client / transaction store becomes the source of truth
|
376
|
+
# this is a bandaid caused by the API storing a Completed state with a bad
|
377
|
+
# result reference that no longer exists
|
378
|
+
if state.is_completed():
|
379
|
+
try:
|
380
|
+
state.result(retry_result_failure=False, _sync=True)
|
381
|
+
except Exception:
|
382
|
+
state = self.set_state(new_state, force=True)
|
383
|
+
|
384
|
+
backoff_count = 0
|
385
|
+
|
386
|
+
# TODO: Could this listen for state change events instead of polling?
|
387
|
+
while state.is_pending() or state.is_paused():
|
388
|
+
if backoff_count < BACKOFF_MAX:
|
389
|
+
backoff_count += 1
|
390
|
+
interval = clamped_poisson_interval(
|
391
|
+
average_interval=backoff_count, clamping_factor=0.3
|
392
|
+
)
|
393
|
+
time.sleep(interval)
|
394
|
+
state = self.set_state(new_state)
|
395
|
+
|
396
|
+
def set_state(self, state: State, force: bool = False) -> State:
|
397
|
+
last_state = self.state
|
398
|
+
if not self.task_run:
|
399
|
+
raise ValueError("Task run is not set")
|
400
|
+
|
401
|
+
if PREFECT_EXPERIMENTAL_ENABLE_CLIENT_SIDE_TASK_ORCHESTRATION:
|
402
|
+
self.task_run.state = new_state = state
|
403
|
+
|
404
|
+
# Ensure that the state_details are populated with the current run IDs
|
405
|
+
new_state.state_details.task_run_id = self.task_run.id
|
406
|
+
new_state.state_details.flow_run_id = self.task_run.flow_run_id
|
407
|
+
|
408
|
+
# Predictively update the de-normalized task_run.state_* attributes
|
409
|
+
self.task_run.state_id = new_state.id
|
410
|
+
self.task_run.state_type = new_state.type
|
411
|
+
self.task_run.state_name = new_state.name
|
412
|
+
else:
|
413
|
+
try:
|
414
|
+
new_state = propose_state_sync(
|
415
|
+
self.client, state, task_run_id=self.task_run.id, force=force
|
416
|
+
)
|
417
|
+
except Pause as exc:
|
418
|
+
# We shouldn't get a pause signal without a state, but if this happens,
|
419
|
+
# just use a Paused state to assume an in-process pause.
|
420
|
+
new_state = exc.state if exc.state else Paused()
|
421
|
+
if new_state.state_details.pause_reschedule:
|
422
|
+
# If we're being asked to pause and reschedule, we should exit the
|
423
|
+
# task and expect to be resumed later.
|
424
|
+
raise
|
425
|
+
|
426
|
+
# currently this is a hack to keep a reference to the state object
|
427
|
+
# that has an in-memory result attached to it; using the API state
|
428
|
+
# could result in losing that reference
|
429
|
+
self.task_run.state = new_state
|
430
|
+
|
431
|
+
# emit a state change event
|
432
|
+
self._last_event = emit_task_run_state_change_event(
|
433
|
+
task_run=self.task_run,
|
434
|
+
initial_state=last_state,
|
435
|
+
validated_state=self.task_run.state,
|
436
|
+
follows=self._last_event,
|
437
|
+
)
|
438
|
+
|
439
|
+
return new_state
|
440
|
+
|
441
|
+
def result(self, raise_on_failure: bool = True) -> "Union[R, State, None]":
|
442
|
+
if self._return_value is not NotSet:
|
443
|
+
# if the return value is a BaseResult, we need to fetch it
|
444
|
+
if isinstance(self._return_value, BaseResult):
|
445
|
+
_result = self._return_value.get()
|
446
|
+
if inspect.isawaitable(_result):
|
447
|
+
_result = run_coro_as_sync(_result)
|
448
|
+
return _result
|
449
|
+
|
450
|
+
# otherwise, return the value as is
|
451
|
+
return self._return_value
|
452
|
+
|
453
|
+
if self._raised is not NotSet:
|
454
|
+
# if the task raised an exception, raise it
|
455
|
+
if raise_on_failure:
|
456
|
+
raise self._raised
|
457
|
+
|
458
|
+
# otherwise, return the exception
|
459
|
+
return self._raised
|
460
|
+
|
461
|
+
def handle_success(self, result: R, transaction: Transaction) -> R:
|
462
|
+
result_factory = getattr(TaskRunContext.get(), "result_factory", None)
|
463
|
+
if result_factory is None:
|
464
|
+
raise ValueError("Result factory is not set")
|
465
|
+
|
466
|
+
if self.task.cache_expiration is not None:
|
467
|
+
expiration = pendulum.now("utc") + self.task.cache_expiration
|
468
|
+
else:
|
469
|
+
expiration = None
|
470
|
+
|
471
|
+
terminal_state = run_coro_as_sync(
|
472
|
+
return_value_to_state(
|
473
|
+
result,
|
474
|
+
result_factory=result_factory,
|
475
|
+
key=transaction.key,
|
476
|
+
expiration=expiration,
|
477
|
+
# defer persistence to transaction commit
|
478
|
+
defer_persistence=True,
|
479
|
+
)
|
480
|
+
)
|
481
|
+
transaction.stage(
|
482
|
+
terminal_state.data,
|
483
|
+
on_rollback_hooks=[self.handle_rollback]
|
484
|
+
+ [
|
485
|
+
_with_transaction_hook_logging(hook, "rollback", self.logger)
|
486
|
+
for hook in self.task.on_rollback_hooks
|
487
|
+
],
|
488
|
+
on_commit_hooks=[
|
489
|
+
_with_transaction_hook_logging(hook, "commit", self.logger)
|
490
|
+
for hook in self.task.on_commit_hooks
|
491
|
+
],
|
492
|
+
)
|
493
|
+
if transaction.is_committed():
|
494
|
+
terminal_state.name = "Cached"
|
495
|
+
|
496
|
+
self.record_terminal_state_timing(terminal_state)
|
497
|
+
self.set_state(terminal_state)
|
498
|
+
self._return_value = result
|
499
|
+
return result
|
500
|
+
|
501
|
+
def handle_retry(self, exc: Exception) -> bool:
|
502
|
+
"""Handle any task run retries.
|
503
|
+
|
504
|
+
- If the task has retries left, and the retry condition is met, set the task to retrying and return True.
|
505
|
+
- If the task has a retry delay, place in AwaitingRetry state with a delayed scheduled time.
|
506
|
+
- If the task has no retries left, or the retry condition is not met, return False.
|
507
|
+
"""
|
508
|
+
if self.retries < self.task.retries and self.can_retry(exc):
|
509
|
+
if self.task.retry_delay_seconds:
|
510
|
+
delay = (
|
511
|
+
self.task.retry_delay_seconds[
|
512
|
+
min(self.retries, len(self.task.retry_delay_seconds) - 1)
|
513
|
+
] # repeat final delay value if attempts exceed specified delays
|
514
|
+
if isinstance(self.task.retry_delay_seconds, Sequence)
|
515
|
+
else self.task.retry_delay_seconds
|
516
|
+
)
|
517
|
+
new_state = AwaitingRetry(
|
518
|
+
scheduled_time=pendulum.now("utc").add(seconds=delay)
|
519
|
+
)
|
520
|
+
else:
|
521
|
+
delay = None
|
522
|
+
new_state = Retrying()
|
523
|
+
if PREFECT_EXPERIMENTAL_ENABLE_CLIENT_SIDE_TASK_ORCHESTRATION:
|
524
|
+
self.task_run.run_count += 1
|
525
|
+
|
526
|
+
self.logger.info(
|
527
|
+
"Task run failed with exception: %r - " "Retry %s/%s will start %s",
|
528
|
+
exc,
|
529
|
+
self.retries + 1,
|
530
|
+
self.task.retries,
|
531
|
+
str(delay) + " second(s) from now" if delay else "immediately",
|
532
|
+
)
|
533
|
+
|
534
|
+
self.set_state(new_state, force=True)
|
535
|
+
self.retries = self.retries + 1
|
536
|
+
return True
|
537
|
+
elif self.retries >= self.task.retries:
|
538
|
+
self.logger.error(
|
539
|
+
"Task run failed with exception: %r - Retries are exhausted",
|
540
|
+
exc,
|
541
|
+
exc_info=True,
|
542
|
+
)
|
543
|
+
return False
|
544
|
+
|
545
|
+
return False
|
546
|
+
|
547
|
+
def handle_exception(self, exc: Exception) -> None:
|
548
|
+
# If the task fails, and we have retries left, set the task to retrying.
|
549
|
+
if not self.handle_retry(exc):
|
550
|
+
# If the task has no retries left, or the retry condition is not met, set the task to failed.
|
551
|
+
context = TaskRunContext.get()
|
552
|
+
state = run_coro_as_sync(
|
553
|
+
exception_to_failed_state(
|
554
|
+
exc,
|
555
|
+
message="Task run encountered an exception",
|
556
|
+
result_factory=getattr(context, "result_factory", None),
|
557
|
+
)
|
558
|
+
)
|
559
|
+
self.record_terminal_state_timing(state)
|
560
|
+
self.set_state(state)
|
561
|
+
self._raised = exc
|
562
|
+
|
563
|
+
def handle_timeout(self, exc: TimeoutError) -> None:
|
564
|
+
if not self.handle_retry(exc):
|
565
|
+
if isinstance(exc, TaskRunTimeoutError):
|
566
|
+
message = f"Task run exceeded timeout of {self.task.timeout_seconds} second(s)"
|
567
|
+
else:
|
568
|
+
message = f"Task run failed due to timeout: {exc!r}"
|
569
|
+
self.logger.error(message)
|
570
|
+
state = Failed(
|
571
|
+
data=exc,
|
572
|
+
message=message,
|
573
|
+
name="TimedOut",
|
574
|
+
)
|
575
|
+
self.set_state(state)
|
576
|
+
self._raised = exc
|
577
|
+
|
578
|
+
def handle_crash(self, exc: BaseException) -> None:
|
579
|
+
state = run_coro_as_sync(exception_to_crashed_state(exc))
|
580
|
+
self.logger.error(f"Crash detected! {state.message}")
|
581
|
+
self.logger.debug("Crash details:", exc_info=exc)
|
582
|
+
self.record_terminal_state_timing(state)
|
583
|
+
self.set_state(state, force=True)
|
584
|
+
self._raised = exc
|
585
|
+
|
586
|
+
@contextmanager
|
587
|
+
def setup_run_context(self, client: Optional[SyncPrefectClient] = None):
|
588
|
+
from prefect.utilities.engine import (
|
589
|
+
_resolve_custom_task_run_name,
|
590
|
+
should_log_prints,
|
591
|
+
)
|
592
|
+
|
593
|
+
if client is None:
|
594
|
+
client = self.client
|
595
|
+
if not self.task_run:
|
596
|
+
raise ValueError("Task run is not set")
|
597
|
+
|
598
|
+
if not PREFECT_EXPERIMENTAL_ENABLE_CLIENT_SIDE_TASK_ORCHESTRATION:
|
599
|
+
self.task_run = client.read_task_run(self.task_run.id)
|
600
|
+
with ExitStack() as stack:
|
601
|
+
if log_prints := should_log_prints(self.task):
|
602
|
+
stack.enter_context(patch_print())
|
603
|
+
stack.enter_context(
|
604
|
+
TaskRunContext(
|
605
|
+
task=self.task,
|
606
|
+
log_prints=log_prints,
|
607
|
+
task_run=self.task_run,
|
608
|
+
parameters=self.parameters,
|
609
|
+
result_factory=run_coro_as_sync(ResultFactory.from_task(self.task)), # type: ignore
|
610
|
+
client=client,
|
611
|
+
)
|
612
|
+
)
|
613
|
+
stack.enter_context(ConcurrencyContext())
|
614
|
+
|
615
|
+
self.logger = task_run_logger(task_run=self.task_run, task=self.task) # type: ignore
|
616
|
+
|
617
|
+
if not PREFECT_EXPERIMENTAL_ENABLE_CLIENT_SIDE_TASK_ORCHESTRATION:
|
618
|
+
# update the task run name if necessary
|
619
|
+
if not self._task_name_set and self.task.task_run_name:
|
620
|
+
task_run_name = _resolve_custom_task_run_name(
|
621
|
+
task=self.task, parameters=self.parameters
|
622
|
+
)
|
623
|
+
self.client.set_task_run_name(
|
624
|
+
task_run_id=self.task_run.id, name=task_run_name
|
625
|
+
)
|
626
|
+
self.logger.extra["task_run_name"] = task_run_name
|
627
|
+
self.logger.debug(
|
628
|
+
f"Renamed task run {self.task_run.name!r} to {task_run_name!r}"
|
629
|
+
)
|
630
|
+
self.task_run.name = task_run_name
|
631
|
+
self._task_name_set = True
|
632
|
+
yield
|
633
|
+
|
634
|
+
@contextmanager
|
635
|
+
def initialize_run(
|
636
|
+
self,
|
637
|
+
task_run_id: Optional[UUID] = None,
|
638
|
+
dependencies: Optional[Dict[str, Set[TaskRunInput]]] = None,
|
639
|
+
) -> Generator["SyncTaskRunEngine", Any, Any]:
|
640
|
+
"""
|
641
|
+
Enters a client context and creates a task run if needed.
|
642
|
+
"""
|
643
|
+
|
644
|
+
with hydrated_context(self.context):
|
645
|
+
with SyncClientContext.get_or_create() as client_ctx:
|
646
|
+
self._client = client_ctx.client
|
647
|
+
self._is_started = True
|
648
|
+
try:
|
649
|
+
if PREFECT_EXPERIMENTAL_ENABLE_CLIENT_SIDE_TASK_ORCHESTRATION:
|
650
|
+
from prefect.utilities.engine import (
|
651
|
+
_resolve_custom_task_run_name,
|
652
|
+
)
|
653
|
+
|
654
|
+
task_run_name = (
|
655
|
+
_resolve_custom_task_run_name(
|
656
|
+
task=self.task, parameters=self.parameters
|
657
|
+
)
|
658
|
+
if self.task.task_run_name
|
659
|
+
else None
|
660
|
+
)
|
661
|
+
|
662
|
+
if self.task_run and task_run_name:
|
663
|
+
self.task_run.name = task_run_name
|
664
|
+
|
665
|
+
if not self.task_run:
|
666
|
+
self.task_run = run_coro_as_sync(
|
667
|
+
self.task.create_local_run(
|
668
|
+
id=task_run_id,
|
669
|
+
parameters=self.parameters,
|
670
|
+
flow_run_context=FlowRunContext.get(),
|
671
|
+
parent_task_run_context=TaskRunContext.get(),
|
672
|
+
wait_for=self.wait_for,
|
673
|
+
extra_task_inputs=dependencies,
|
674
|
+
task_run_name=task_run_name,
|
675
|
+
)
|
676
|
+
)
|
677
|
+
else:
|
678
|
+
if not self.task_run:
|
679
|
+
self.task_run = run_coro_as_sync(
|
680
|
+
self.task.create_run(
|
681
|
+
id=task_run_id,
|
682
|
+
parameters=self.parameters,
|
683
|
+
flow_run_context=FlowRunContext.get(),
|
684
|
+
parent_task_run_context=TaskRunContext.get(),
|
685
|
+
wait_for=self.wait_for,
|
686
|
+
extra_task_inputs=dependencies,
|
687
|
+
)
|
688
|
+
)
|
689
|
+
# Emit an event to capture that the task run was in the `PENDING` state.
|
690
|
+
self._last_event = emit_task_run_state_change_event(
|
691
|
+
task_run=self.task_run,
|
692
|
+
initial_state=None,
|
693
|
+
validated_state=self.task_run.state,
|
694
|
+
)
|
695
|
+
|
696
|
+
with self.setup_run_context():
|
697
|
+
# setup_run_context might update the task run name, so log creation here
|
698
|
+
self.logger.info(
|
699
|
+
f"Created task run {self.task_run.name!r} for task {self.task.name!r}"
|
700
|
+
)
|
701
|
+
yield self
|
702
|
+
|
703
|
+
except TerminationSignal as exc:
|
704
|
+
# TerminationSignals are caught and handled as crashes
|
705
|
+
self.handle_crash(exc)
|
706
|
+
raise exc
|
707
|
+
|
708
|
+
except Exception:
|
709
|
+
# regular exceptions are caught and re-raised to the user
|
710
|
+
raise
|
711
|
+
except (Pause, Abort) as exc:
|
712
|
+
# Do not capture internal signals as crashes
|
713
|
+
if isinstance(exc, Abort):
|
714
|
+
self.logger.error("Task run was aborted: %s", exc)
|
715
|
+
raise
|
716
|
+
except GeneratorExit:
|
717
|
+
# Do not capture generator exits as crashes
|
718
|
+
raise
|
719
|
+
except BaseException as exc:
|
720
|
+
# BaseExceptions are caught and handled as crashes
|
721
|
+
self.handle_crash(exc)
|
722
|
+
raise
|
723
|
+
finally:
|
724
|
+
self.log_finished_message()
|
725
|
+
self._is_started = False
|
726
|
+
self._client = None
|
727
|
+
|
728
|
+
async def wait_until_ready(self):
|
729
|
+
"""Waits until the scheduled time (if its the future), then enters Running."""
|
730
|
+
if scheduled_time := self.state.state_details.scheduled_time:
|
731
|
+
sleep_time = (scheduled_time - pendulum.now("utc")).total_seconds()
|
732
|
+
await anyio.sleep(sleep_time if sleep_time > 0 else 0)
|
733
|
+
self.set_state(
|
734
|
+
Retrying() if self.state.name == "AwaitingRetry" else Running(),
|
735
|
+
force=True,
|
736
|
+
)
|
737
|
+
|
738
|
+
# --------------------------
|
739
|
+
#
|
740
|
+
# The following methods compose the main task run loop
|
741
|
+
#
|
742
|
+
# --------------------------
|
743
|
+
|
744
|
+
@contextmanager
|
745
|
+
def start(
|
746
|
+
self,
|
747
|
+
task_run_id: Optional[UUID] = None,
|
748
|
+
dependencies: Optional[Dict[str, Set[TaskRunInput]]] = None,
|
749
|
+
) -> Generator[None, None, None]:
|
750
|
+
with self.initialize_run(task_run_id=task_run_id, dependencies=dependencies):
|
751
|
+
self.begin_run()
|
752
|
+
try:
|
753
|
+
yield
|
754
|
+
finally:
|
755
|
+
self.call_hooks()
|
756
|
+
|
757
|
+
@contextmanager
|
758
|
+
def transaction_context(self) -> Generator[Transaction, None, None]:
|
759
|
+
result_factory = getattr(TaskRunContext.get(), "result_factory", None)
|
760
|
+
|
761
|
+
# refresh cache setting is now repurposes as overwrite transaction record
|
762
|
+
overwrite = (
|
763
|
+
self.task.refresh_cache
|
764
|
+
if self.task.refresh_cache is not None
|
765
|
+
else PREFECT_TASKS_REFRESH_CACHE.value()
|
766
|
+
)
|
767
|
+
with transaction(
|
768
|
+
key=self.compute_transaction_key(),
|
769
|
+
store=ResultFactoryStore(result_factory=result_factory),
|
770
|
+
overwrite=overwrite,
|
771
|
+
logger=self.logger,
|
772
|
+
) as txn:
|
773
|
+
yield txn
|
774
|
+
|
775
|
+
@contextmanager
|
776
|
+
def run_context(self):
|
777
|
+
# reenter the run context to ensure it is up to date for every run
|
778
|
+
with self.setup_run_context():
|
779
|
+
try:
|
780
|
+
with timeout(
|
781
|
+
seconds=self.task.timeout_seconds,
|
782
|
+
timeout_exc_type=TaskRunTimeoutError,
|
783
|
+
):
|
784
|
+
self.logger.debug(
|
785
|
+
f"Executing task {self.task.name!r} for task run {self.task_run.name!r}..."
|
786
|
+
)
|
787
|
+
if self.is_cancelled():
|
788
|
+
raise CancelledError("Task run cancelled by the task runner")
|
789
|
+
|
790
|
+
yield self
|
791
|
+
except TimeoutError as exc:
|
792
|
+
self.handle_timeout(exc)
|
793
|
+
except Exception as exc:
|
794
|
+
self.handle_exception(exc)
|
795
|
+
|
796
|
+
def call_task_fn(
|
797
|
+
self, transaction: Transaction
|
798
|
+
) -> Union[R, Coroutine[Any, Any, R]]:
|
799
|
+
"""
|
800
|
+
Convenience method to call the task function. Returns a coroutine if the
|
801
|
+
task is async.
|
802
|
+
"""
|
803
|
+
parameters = self.parameters or {}
|
804
|
+
if transaction.is_committed():
|
805
|
+
result = transaction.read()
|
806
|
+
else:
|
807
|
+
if (
|
808
|
+
PREFECT_EXPERIMENTAL_ENABLE_CLIENT_SIDE_TASK_ORCHESTRATION.value()
|
809
|
+
and self.task.tags
|
810
|
+
):
|
811
|
+
# Acquire a concurrency slot for each tag, but only if a limit
|
812
|
+
# matching the tag already exists.
|
813
|
+
with concurrency(
|
814
|
+
list(self.task.tags), occupy=1, create_if_missing=False
|
815
|
+
):
|
816
|
+
result = call_with_parameters(self.task.fn, parameters)
|
817
|
+
else:
|
818
|
+
result = call_with_parameters(self.task.fn, parameters)
|
819
|
+
self.handle_success(result, transaction=transaction)
|
820
|
+
return result
|
821
|
+
|
112
822
|
|
113
|
-
|
114
|
-
|
115
|
-
|
823
|
+
@dataclass
|
824
|
+
class AsyncTaskRunEngine(BaseTaskRunEngine[P, R]):
|
825
|
+
_client: Optional[PrefectClient] = None
|
116
826
|
|
117
827
|
@property
|
118
|
-
def client(self) ->
|
828
|
+
def client(self) -> PrefectClient:
|
119
829
|
if not self._is_started or self._client is None:
|
120
830
|
raise RuntimeError("Engine has not started.")
|
121
831
|
return self._client
|
122
832
|
|
123
|
-
|
124
|
-
def state(self) -> State:
|
125
|
-
if not self.task_run:
|
126
|
-
raise ValueError("Task run is not set")
|
127
|
-
return self.task_run.state
|
128
|
-
|
129
|
-
def can_retry(self, exc: Exception) -> bool:
|
833
|
+
async def can_retry(self, exc: Exception) -> bool:
|
130
834
|
retry_condition: Optional[
|
131
835
|
Callable[[Task[P, Coroutine[Any, Any, R]], TaskRun, State], bool]
|
132
836
|
] = self.task.retry_condition_fn
|
@@ -142,14 +846,13 @@ class TaskRunEngine(Generic[P, R]):
|
|
142
846
|
message=f"Task run encountered unexpected exception: {repr(exc)}",
|
143
847
|
)
|
144
848
|
if inspect.iscoroutinefunction(retry_condition):
|
145
|
-
should_retry =
|
146
|
-
retry_condition(self.task, self.task_run, state)
|
147
|
-
)
|
849
|
+
should_retry = await retry_condition(self.task, self.task_run, state)
|
148
850
|
elif inspect.isfunction(retry_condition):
|
149
851
|
should_retry = retry_condition(self.task, self.task_run, state)
|
150
852
|
else:
|
151
853
|
should_retry = not retry_condition
|
152
854
|
return should_retry
|
855
|
+
|
153
856
|
except Exception:
|
154
857
|
self.logger.error(
|
155
858
|
(
|
@@ -160,16 +863,7 @@ class TaskRunEngine(Generic[P, R]):
|
|
160
863
|
)
|
161
864
|
return False
|
162
865
|
|
163
|
-
def
|
164
|
-
if (
|
165
|
-
self.context
|
166
|
-
and "cancel_event" in self.context
|
167
|
-
and isinstance(self.context["cancel_event"], threading.Event)
|
168
|
-
):
|
169
|
-
return self.context["cancel_event"].is_set()
|
170
|
-
return False
|
171
|
-
|
172
|
-
def call_hooks(self, state: Optional[State] = None):
|
866
|
+
async def call_hooks(self, state: Optional[State] = None):
|
173
867
|
if state is None:
|
174
868
|
state = self.state
|
175
869
|
task = self.task
|
@@ -195,7 +889,7 @@ class TaskRunEngine(Generic[P, R]):
|
|
195
889
|
)
|
196
890
|
result = hook(task, task_run, state)
|
197
891
|
if inspect.isawaitable(result):
|
198
|
-
|
892
|
+
await result
|
199
893
|
except Exception:
|
200
894
|
self.logger.error(
|
201
895
|
f"An error was encountered while running hook {hook_name!r}",
|
@@ -204,71 +898,12 @@ class TaskRunEngine(Generic[P, R]):
|
|
204
898
|
else:
|
205
899
|
self.logger.info(f"Hook {hook_name!r} finished running successfully")
|
206
900
|
|
207
|
-
def
|
208
|
-
key = None
|
209
|
-
if self.task.cache_policy:
|
210
|
-
flow_run_context = FlowRunContext.get()
|
211
|
-
task_run_context = TaskRunContext.get()
|
212
|
-
|
213
|
-
if flow_run_context:
|
214
|
-
parameters = flow_run_context.parameters
|
215
|
-
else:
|
216
|
-
parameters = None
|
217
|
-
|
218
|
-
key = self.task.cache_policy.compute_key(
|
219
|
-
task_ctx=task_run_context,
|
220
|
-
inputs=self.parameters,
|
221
|
-
flow_parameters=parameters,
|
222
|
-
)
|
223
|
-
elif self.task.result_storage_key is not None:
|
224
|
-
key = _format_user_supplied_storage_key(self.task.result_storage_key)
|
225
|
-
return key
|
226
|
-
|
227
|
-
def _resolve_parameters(self):
|
228
|
-
if not self.parameters:
|
229
|
-
return {}
|
230
|
-
|
231
|
-
resolved_parameters = {}
|
232
|
-
for parameter, value in self.parameters.items():
|
233
|
-
try:
|
234
|
-
resolved_parameters[parameter] = visit_collection(
|
235
|
-
value,
|
236
|
-
visit_fn=resolve_to_final_result,
|
237
|
-
return_data=True,
|
238
|
-
max_depth=-1,
|
239
|
-
remove_annotations=True,
|
240
|
-
context={},
|
241
|
-
)
|
242
|
-
except UpstreamTaskError:
|
243
|
-
raise
|
244
|
-
except Exception as exc:
|
245
|
-
raise PrefectException(
|
246
|
-
f"Failed to resolve inputs in parameter {parameter!r}. If your"
|
247
|
-
" parameter type is not supported, consider using the `quote`"
|
248
|
-
" annotation to skip resolution of inputs."
|
249
|
-
) from exc
|
250
|
-
|
251
|
-
self.parameters = resolved_parameters
|
252
|
-
|
253
|
-
def _wait_for_dependencies(self):
|
254
|
-
if not self.wait_for:
|
255
|
-
return
|
256
|
-
|
257
|
-
visit_collection(
|
258
|
-
self.wait_for,
|
259
|
-
visit_fn=resolve_to_final_result,
|
260
|
-
return_data=False,
|
261
|
-
max_depth=-1,
|
262
|
-
remove_annotations=True,
|
263
|
-
context={"current_task_run": self.task_run, "current_task": self.task},
|
264
|
-
)
|
265
|
-
|
266
|
-
def begin_run(self):
|
901
|
+
async def begin_run(self):
|
267
902
|
try:
|
268
903
|
self._resolve_parameters()
|
269
904
|
self._wait_for_dependencies()
|
270
905
|
except UpstreamTaskError as upstream_exc:
|
271
|
-
state = self.set_state(
|
906
|
+
state = await self.set_state(
|
272
907
|
Pending(
|
273
908
|
name="NotReady",
|
274
909
|
message=str(upstream_exc),
|
@@ -291,7 +926,7 @@ class TaskRunEngine(Generic[P, R]):
|
|
291
926
|
flow_run = flow_run_context.flow_run
|
292
927
|
self.task_run.flow_run_run_count = flow_run.run_count
|
293
928
|
|
294
|
-
state = self.set_state(new_state)
|
929
|
+
state = await self.set_state(new_state)
|
295
930
|
|
296
931
|
# TODO: this is temporary until the API stops rejecting state transitions
|
297
932
|
# and the client / transaction store becomes the source of truth
|
@@ -299,11 +934,10 @@ class TaskRunEngine(Generic[P, R]):
|
|
299
934
|
# result reference that no longer exists
|
300
935
|
if state.is_completed():
|
301
936
|
try:
|
302
|
-
state.result(retry_result_failure=False
|
937
|
+
await state.result(retry_result_failure=False)
|
303
938
|
except Exception:
|
304
|
-
state = self.set_state(new_state, force=True)
|
939
|
+
state = await self.set_state(new_state, force=True)
|
305
940
|
|
306
|
-
BACKOFF_MAX = 10
|
307
941
|
backoff_count = 0
|
308
942
|
|
309
943
|
# TODO: Could this listen for state change events instead of polling?
|
@@ -313,10 +947,10 @@ class TaskRunEngine(Generic[P, R]):
|
|
313
947
|
interval = clamped_poisson_interval(
|
314
948
|
average_interval=backoff_count, clamping_factor=0.3
|
315
949
|
)
|
316
|
-
|
317
|
-
state = self.set_state(new_state)
|
950
|
+
await anyio.sleep(interval)
|
951
|
+
state = await self.set_state(new_state)
|
318
952
|
|
319
|
-
def set_state(self, state: State, force: bool = False) -> State:
|
953
|
+
async def set_state(self, state: State, force: bool = False) -> State:
|
320
954
|
last_state = self.state
|
321
955
|
if not self.task_run:
|
322
956
|
raise ValueError("Task run is not set")
|
@@ -334,7 +968,7 @@ class TaskRunEngine(Generic[P, R]):
|
|
334
968
|
self.task_run.state_name = new_state.name
|
335
969
|
else:
|
336
970
|
try:
|
337
|
-
new_state =
|
971
|
+
new_state = await propose_state(
|
338
972
|
self.client, state, task_run_id=self.task_run.id, force=force
|
339
973
|
)
|
340
974
|
except Pause as exc:
|
@@ -361,14 +995,11 @@ class TaskRunEngine(Generic[P, R]):
|
|
361
995
|
|
362
996
|
return new_state
|
363
997
|
|
364
|
-
def result(self, raise_on_failure: bool = True) -> "Union[R, State, None]":
|
998
|
+
async def result(self, raise_on_failure: bool = True) -> "Union[R, State, None]":
|
365
999
|
if self._return_value is not NotSet:
|
366
1000
|
# if the return value is a BaseResult, we need to fetch it
|
367
1001
|
if isinstance(self._return_value, BaseResult):
|
368
|
-
|
369
|
-
if inspect.isawaitable(_result):
|
370
|
-
_result = run_coro_as_sync(_result)
|
371
|
-
return _result
|
1002
|
+
return await self._return_value.get()
|
372
1003
|
|
373
1004
|
# otherwise, return the value as is
|
374
1005
|
return self._return_value
|
@@ -381,7 +1012,7 @@ class TaskRunEngine(Generic[P, R]):
|
|
381
1012
|
# otherwise, return the exception
|
382
1013
|
return self._raised
|
383
1014
|
|
384
|
-
def handle_success(self, result: R, transaction: Transaction) -> R:
|
1015
|
+
async def handle_success(self, result: R, transaction: Transaction) -> R:
|
385
1016
|
result_factory = getattr(TaskRunContext.get(), "result_factory", None)
|
386
1017
|
if result_factory is None:
|
387
1018
|
raise ValueError("Result factory is not set")
|
@@ -391,19 +1022,18 @@ class TaskRunEngine(Generic[P, R]):
|
|
391
1022
|
else:
|
392
1023
|
expiration = None
|
393
1024
|
|
394
|
-
terminal_state =
|
395
|
-
|
396
|
-
|
397
|
-
|
398
|
-
|
399
|
-
|
400
|
-
|
401
|
-
defer_persistence=True,
|
402
|
-
)
|
1025
|
+
terminal_state = await return_value_to_state(
|
1026
|
+
result,
|
1027
|
+
result_factory=result_factory,
|
1028
|
+
key=transaction.key,
|
1029
|
+
expiration=expiration,
|
1030
|
+
# defer persistence to transaction commit
|
1031
|
+
defer_persistence=True,
|
403
1032
|
)
|
404
1033
|
transaction.stage(
|
405
1034
|
terminal_state.data,
|
406
|
-
on_rollback_hooks=[
|
1035
|
+
on_rollback_hooks=[self.handle_rollback]
|
1036
|
+
+ [
|
407
1037
|
_with_transaction_hook_logging(hook, "rollback", self.logger)
|
408
1038
|
for hook in self.task.on_rollback_hooks
|
409
1039
|
],
|
@@ -416,18 +1046,18 @@ class TaskRunEngine(Generic[P, R]):
|
|
416
1046
|
terminal_state.name = "Cached"
|
417
1047
|
|
418
1048
|
self.record_terminal_state_timing(terminal_state)
|
419
|
-
self.set_state(terminal_state)
|
1049
|
+
await self.set_state(terminal_state)
|
420
1050
|
self._return_value = result
|
421
1051
|
return result
|
422
1052
|
|
423
|
-
def handle_retry(self, exc: Exception) -> bool:
|
1053
|
+
async def handle_retry(self, exc: Exception) -> bool:
|
424
1054
|
"""Handle any task run retries.
|
425
1055
|
|
426
1056
|
- If the task has retries left, and the retry condition is met, set the task to retrying and return True.
|
427
|
-
|
1057
|
+
- If the task has a retry delay, place in AwaitingRetry state with a delayed scheduled time.
|
428
1058
|
- If the task has no retries left, or the retry condition is not met, return False.
|
429
1059
|
"""
|
430
|
-
if self.retries < self.task.retries and self.can_retry(exc):
|
1060
|
+
if self.retries < self.task.retries and await self.can_retry(exc):
|
431
1061
|
if self.task.retry_delay_seconds:
|
432
1062
|
delay = (
|
433
1063
|
self.task.retry_delay_seconds[
|
@@ -453,7 +1083,7 @@ class TaskRunEngine(Generic[P, R]):
|
|
453
1083
|
str(delay) + " second(s) from now" if delay else "immediately",
|
454
1084
|
)
|
455
1085
|
|
456
|
-
self.set_state(new_state, force=True)
|
1086
|
+
await self.set_state(new_state, force=True)
|
457
1087
|
self.retries = self.retries + 1
|
458
1088
|
return True
|
459
1089
|
elif self.retries >= self.task.retries:
|
@@ -466,24 +1096,22 @@ class TaskRunEngine(Generic[P, R]):
|
|
466
1096
|
|
467
1097
|
return False
|
468
1098
|
|
469
|
-
def handle_exception(self, exc: Exception) -> None:
|
1099
|
+
async def handle_exception(self, exc: Exception) -> None:
|
470
1100
|
# If the task fails, and we have retries left, set the task to retrying.
|
471
|
-
if not self.handle_retry(exc):
|
1101
|
+
if not await self.handle_retry(exc):
|
472
1102
|
# If the task has no retries left, or the retry condition is not met, set the task to failed.
|
473
1103
|
context = TaskRunContext.get()
|
474
|
-
state =
|
475
|
-
|
476
|
-
|
477
|
-
|
478
|
-
result_factory=getattr(context, "result_factory", None),
|
479
|
-
)
|
1104
|
+
state = await exception_to_failed_state(
|
1105
|
+
exc,
|
1106
|
+
message="Task run encountered an exception",
|
1107
|
+
result_factory=getattr(context, "result_factory", None),
|
480
1108
|
)
|
481
1109
|
self.record_terminal_state_timing(state)
|
482
|
-
self.set_state(state)
|
1110
|
+
await self.set_state(state)
|
483
1111
|
self._raised = exc
|
484
1112
|
|
485
|
-
def handle_timeout(self, exc: TimeoutError) -> None:
|
486
|
-
if not self.handle_retry(exc):
|
1113
|
+
async def handle_timeout(self, exc: TimeoutError) -> None:
|
1114
|
+
if not await self.handle_retry(exc):
|
487
1115
|
if isinstance(exc, TaskRunTimeoutError):
|
488
1116
|
message = f"Task run exceeded timeout of {self.task.timeout_seconds} second(s)"
|
489
1117
|
else:
|
@@ -494,29 +1122,19 @@ class TaskRunEngine(Generic[P, R]):
|
|
494
1122
|
message=message,
|
495
1123
|
name="TimedOut",
|
496
1124
|
)
|
497
|
-
self.set_state(state)
|
1125
|
+
await self.set_state(state)
|
498
1126
|
self._raised = exc
|
499
1127
|
|
500
|
-
def handle_crash(self, exc: BaseException) -> None:
|
501
|
-
state =
|
1128
|
+
async def handle_crash(self, exc: BaseException) -> None:
|
1129
|
+
state = await exception_to_crashed_state(exc)
|
502
1130
|
self.logger.error(f"Crash detected! {state.message}")
|
503
1131
|
self.logger.debug("Crash details:", exc_info=exc)
|
504
1132
|
self.record_terminal_state_timing(state)
|
505
|
-
self.set_state(state, force=True)
|
1133
|
+
await self.set_state(state, force=True)
|
506
1134
|
self._raised = exc
|
507
1135
|
|
508
|
-
|
509
|
-
|
510
|
-
if self.task_run.start_time and not self.task_run.end_time:
|
511
|
-
self.task_run.end_time = state.timestamp
|
512
|
-
|
513
|
-
if self.task_run.state.is_running():
|
514
|
-
self.task_run.total_run_time += (
|
515
|
-
state.timestamp - self.task_run.state.timestamp
|
516
|
-
)
|
517
|
-
|
518
|
-
@contextmanager
|
519
|
-
def setup_run_context(self, client: Optional[SyncPrefectClient] = None):
|
1136
|
+
@asynccontextmanager
|
1137
|
+
async def setup_run_context(self, client: Optional[PrefectClient] = None):
|
520
1138
|
from prefect.utilities.engine import (
|
521
1139
|
_resolve_custom_task_run_name,
|
522
1140
|
should_log_prints,
|
@@ -528,7 +1146,7 @@ class TaskRunEngine(Generic[P, R]):
|
|
528
1146
|
raise ValueError("Task run is not set")
|
529
1147
|
|
530
1148
|
if not PREFECT_EXPERIMENTAL_ENABLE_CLIENT_SIDE_TASK_ORCHESTRATION:
|
531
|
-
self.task_run = client.read_task_run(self.task_run.id)
|
1149
|
+
self.task_run = await client.read_task_run(self.task_run.id)
|
532
1150
|
with ExitStack() as stack:
|
533
1151
|
if log_prints := should_log_prints(self.task):
|
534
1152
|
stack.enter_context(patch_print())
|
@@ -538,10 +1156,11 @@ class TaskRunEngine(Generic[P, R]):
|
|
538
1156
|
log_prints=log_prints,
|
539
1157
|
task_run=self.task_run,
|
540
1158
|
parameters=self.parameters,
|
541
|
-
result_factory=
|
1159
|
+
result_factory=await ResultFactory.from_task(self.task), # type: ignore
|
542
1160
|
client=client,
|
543
1161
|
)
|
544
1162
|
)
|
1163
|
+
stack.enter_context(ConcurrencyContext())
|
545
1164
|
|
546
1165
|
self.logger = task_run_logger(task_run=self.task_run, task=self.task) # type: ignore
|
547
1166
|
|
@@ -551,7 +1170,7 @@ class TaskRunEngine(Generic[P, R]):
|
|
551
1170
|
task_run_name = _resolve_custom_task_run_name(
|
552
1171
|
task=self.task, parameters=self.parameters
|
553
1172
|
)
|
554
|
-
self.client.set_task_run_name(
|
1173
|
+
await self.client.set_task_run_name(
|
555
1174
|
task_run_id=self.task_run.id, name=task_run_name
|
556
1175
|
)
|
557
1176
|
self.logger.extra["task_run_name"] = task_run_name
|
@@ -562,55 +1181,56 @@ class TaskRunEngine(Generic[P, R]):
|
|
562
1181
|
self._task_name_set = True
|
563
1182
|
yield
|
564
1183
|
|
565
|
-
@
|
566
|
-
def initialize_run(
|
1184
|
+
@asynccontextmanager
|
1185
|
+
async def initialize_run(
|
567
1186
|
self,
|
568
1187
|
task_run_id: Optional[UUID] = None,
|
569
1188
|
dependencies: Optional[Dict[str, Set[TaskRunInput]]] = None,
|
570
|
-
) ->
|
1189
|
+
) -> AsyncGenerator["AsyncTaskRunEngine", Any]:
|
571
1190
|
"""
|
572
1191
|
Enters a client context and creates a task run if needed.
|
573
1192
|
"""
|
574
1193
|
|
575
1194
|
with hydrated_context(self.context):
|
576
|
-
with
|
577
|
-
self._client =
|
1195
|
+
async with AsyncClientContext.get_or_create():
|
1196
|
+
self._client = get_client()
|
578
1197
|
self._is_started = True
|
579
1198
|
try:
|
580
|
-
if
|
581
|
-
|
582
|
-
|
583
|
-
|
584
|
-
_resolve_custom_task_run_name,
|
585
|
-
)
|
1199
|
+
if PREFECT_EXPERIMENTAL_ENABLE_CLIENT_SIDE_TASK_ORCHESTRATION:
|
1200
|
+
from prefect.utilities.engine import (
|
1201
|
+
_resolve_custom_task_run_name,
|
1202
|
+
)
|
586
1203
|
|
587
|
-
|
588
|
-
|
589
|
-
|
590
|
-
|
591
|
-
|
1204
|
+
task_run_name = (
|
1205
|
+
_resolve_custom_task_run_name(
|
1206
|
+
task=self.task, parameters=self.parameters
|
1207
|
+
)
|
1208
|
+
if self.task.task_run_name
|
1209
|
+
else None
|
1210
|
+
)
|
592
1211
|
|
593
|
-
|
594
|
-
|
595
|
-
|
596
|
-
|
597
|
-
|
598
|
-
|
599
|
-
|
600
|
-
|
601
|
-
|
602
|
-
|
1212
|
+
if self.task_run and task_run_name:
|
1213
|
+
self.task_run.name = task_run_name
|
1214
|
+
|
1215
|
+
if not self.task_run:
|
1216
|
+
self.task_run = await self.task.create_local_run(
|
1217
|
+
id=task_run_id,
|
1218
|
+
parameters=self.parameters,
|
1219
|
+
flow_run_context=FlowRunContext.get(),
|
1220
|
+
parent_task_run_context=TaskRunContext.get(),
|
1221
|
+
wait_for=self.wait_for,
|
1222
|
+
extra_task_inputs=dependencies,
|
1223
|
+
task_run_name=task_run_name,
|
603
1224
|
)
|
604
|
-
|
605
|
-
|
606
|
-
|
607
|
-
|
608
|
-
|
609
|
-
|
610
|
-
|
611
|
-
|
612
|
-
|
613
|
-
)
|
1225
|
+
else:
|
1226
|
+
if not self.task_run:
|
1227
|
+
self.task_run = await self.task.create_run(
|
1228
|
+
id=task_run_id,
|
1229
|
+
parameters=self.parameters,
|
1230
|
+
flow_run_context=FlowRunContext.get(),
|
1231
|
+
parent_task_run_context=TaskRunContext.get(),
|
1232
|
+
wait_for=self.wait_for,
|
1233
|
+
extra_task_inputs=dependencies,
|
614
1234
|
)
|
615
1235
|
# Emit an event to capture that the task run was in the `PENDING` state.
|
616
1236
|
self._last_event = emit_task_run_state_change_event(
|
@@ -619,7 +1239,7 @@ class TaskRunEngine(Generic[P, R]):
|
|
619
1239
|
validated_state=self.task_run.state,
|
620
1240
|
)
|
621
1241
|
|
622
|
-
with self.setup_run_context():
|
1242
|
+
async with self.setup_run_context():
|
623
1243
|
# setup_run_context might update the task run name, so log creation here
|
624
1244
|
self.logger.info(
|
625
1245
|
f"Created task run {self.task_run.name!r} for task {self.task.name!r}"
|
@@ -628,7 +1248,7 @@ class TaskRunEngine(Generic[P, R]):
|
|
628
1248
|
|
629
1249
|
except TerminationSignal as exc:
|
630
1250
|
# TerminationSignals are caught and handled as crashes
|
631
|
-
self.handle_crash(exc)
|
1251
|
+
await self.handle_crash(exc)
|
632
1252
|
raise exc
|
633
1253
|
|
634
1254
|
except Exception:
|
@@ -644,60 +1264,19 @@ class TaskRunEngine(Generic[P, R]):
|
|
644
1264
|
raise
|
645
1265
|
except BaseException as exc:
|
646
1266
|
# BaseExceptions are caught and handled as crashes
|
647
|
-
self.handle_crash(exc)
|
1267
|
+
await self.handle_crash(exc)
|
648
1268
|
raise
|
649
1269
|
finally:
|
650
|
-
|
651
|
-
display_state = (
|
652
|
-
repr(self.state) if PREFECT_DEBUG_MODE else str(self.state)
|
653
|
-
)
|
654
|
-
level = logging.INFO if self.state.is_completed() else logging.ERROR
|
655
|
-
msg = f"Finished in state {display_state}"
|
656
|
-
if self.state.is_pending():
|
657
|
-
msg += (
|
658
|
-
"\nPlease wait for all submitted tasks to complete"
|
659
|
-
" before exiting your flow by calling `.wait()` on the "
|
660
|
-
"`PrefectFuture` returned from your `.submit()` calls."
|
661
|
-
)
|
662
|
-
msg += dedent(
|
663
|
-
"""
|
664
|
-
|
665
|
-
Example:
|
666
|
-
|
667
|
-
from prefect import flow, task
|
668
|
-
|
669
|
-
@task
|
670
|
-
def say_hello(name):
|
671
|
-
print f"Hello, {name}!"
|
672
|
-
|
673
|
-
@flow
|
674
|
-
def example_flow():
|
675
|
-
future = say_hello.submit(name="Marvin)
|
676
|
-
future.wait()
|
677
|
-
|
678
|
-
example_flow()
|
679
|
-
"""
|
680
|
-
)
|
681
|
-
self.logger.log(
|
682
|
-
level=level,
|
683
|
-
msg=msg,
|
684
|
-
)
|
685
|
-
|
1270
|
+
self.log_finished_message()
|
686
1271
|
self._is_started = False
|
687
1272
|
self._client = None
|
688
1273
|
|
689
|
-
def is_running(self) -> bool:
|
690
|
-
"""Whether or not the engine is currently running a task."""
|
691
|
-
if (task_run := getattr(self, "task_run", None)) is None:
|
692
|
-
return False
|
693
|
-
return task_run.state.is_running() or task_run.state.is_scheduled()
|
694
|
-
|
695
1274
|
async def wait_until_ready(self):
|
696
1275
|
"""Waits until the scheduled time (if its the future), then enters Running."""
|
697
1276
|
if scheduled_time := self.state.state_details.scheduled_time:
|
698
1277
|
sleep_time = (scheduled_time - pendulum.now("utc")).total_seconds()
|
699
1278
|
await anyio.sleep(sleep_time if sleep_time > 0 else 0)
|
700
|
-
self.set_state(
|
1279
|
+
await self.set_state(
|
701
1280
|
Retrying() if self.state.name == "AwaitingRetry" else Running(),
|
702
1281
|
force=True,
|
703
1282
|
)
|
@@ -708,21 +1287,23 @@ class TaskRunEngine(Generic[P, R]):
|
|
708
1287
|
#
|
709
1288
|
# --------------------------
|
710
1289
|
|
711
|
-
@
|
712
|
-
def start(
|
1290
|
+
@asynccontextmanager
|
1291
|
+
async def start(
|
713
1292
|
self,
|
714
1293
|
task_run_id: Optional[UUID] = None,
|
715
1294
|
dependencies: Optional[Dict[str, Set[TaskRunInput]]] = None,
|
716
|
-
) ->
|
717
|
-
with self.initialize_run(
|
718
|
-
|
1295
|
+
) -> AsyncGenerator[None, None]:
|
1296
|
+
async with self.initialize_run(
|
1297
|
+
task_run_id=task_run_id, dependencies=dependencies
|
1298
|
+
):
|
1299
|
+
await self.begin_run()
|
719
1300
|
try:
|
720
1301
|
yield
|
721
1302
|
finally:
|
722
|
-
self.call_hooks()
|
1303
|
+
await self.call_hooks()
|
723
1304
|
|
724
|
-
@
|
725
|
-
def transaction_context(self) ->
|
1305
|
+
@asynccontextmanager
|
1306
|
+
async def transaction_context(self) -> AsyncGenerator[Transaction, None]:
|
726
1307
|
result_factory = getattr(TaskRunContext.get(), "result_factory", None)
|
727
1308
|
|
728
1309
|
# refresh cache setting is now repurposes as overwrite transaction record
|
@@ -739,13 +1320,12 @@ class TaskRunEngine(Generic[P, R]):
|
|
739
1320
|
) as txn:
|
740
1321
|
yield txn
|
741
1322
|
|
742
|
-
@
|
743
|
-
def run_context(self):
|
744
|
-
timeout_context = timeout_async if self.task.isasync else timeout
|
1323
|
+
@asynccontextmanager
|
1324
|
+
async def run_context(self):
|
745
1325
|
# reenter the run context to ensure it is up to date for every run
|
746
|
-
with self.setup_run_context():
|
1326
|
+
async with self.setup_run_context():
|
747
1327
|
try:
|
748
|
-
with
|
1328
|
+
with timeout_async(
|
749
1329
|
seconds=self.task.timeout_seconds,
|
750
1330
|
timeout_exc_type=TaskRunTimeoutError,
|
751
1331
|
):
|
@@ -757,11 +1337,11 @@ class TaskRunEngine(Generic[P, R]):
|
|
757
1337
|
|
758
1338
|
yield self
|
759
1339
|
except TimeoutError as exc:
|
760
|
-
self.handle_timeout(exc)
|
1340
|
+
await self.handle_timeout(exc)
|
761
1341
|
except Exception as exc:
|
762
|
-
self.handle_exception(exc)
|
1342
|
+
await self.handle_exception(exc)
|
763
1343
|
|
764
|
-
def call_task_fn(
|
1344
|
+
async def call_task_fn(
|
765
1345
|
self, transaction: Transaction
|
766
1346
|
) -> Union[R, Coroutine[Any, Any, R]]:
|
767
1347
|
"""
|
@@ -769,24 +1349,23 @@ class TaskRunEngine(Generic[P, R]):
|
|
769
1349
|
task is async.
|
770
1350
|
"""
|
771
1351
|
parameters = self.parameters or {}
|
772
|
-
if
|
773
|
-
|
774
|
-
async def _call_task_fn():
|
775
|
-
if transaction.is_committed():
|
776
|
-
result = transaction.read()
|
777
|
-
else:
|
778
|
-
result = await call_with_parameters(self.task.fn, parameters)
|
779
|
-
self.handle_success(result, transaction=transaction)
|
780
|
-
return result
|
781
|
-
|
782
|
-
return _call_task_fn()
|
1352
|
+
if transaction.is_committed():
|
1353
|
+
result = transaction.read()
|
783
1354
|
else:
|
784
|
-
if
|
785
|
-
|
1355
|
+
if (
|
1356
|
+
PREFECT_EXPERIMENTAL_ENABLE_CLIENT_SIDE_TASK_ORCHESTRATION.value()
|
1357
|
+
and self.task.tags
|
1358
|
+
):
|
1359
|
+
# Acquire a concurrency slot for each tag, but only if a limit
|
1360
|
+
# matching the tag already exists.
|
1361
|
+
async with aconcurrency(
|
1362
|
+
list(self.task.tags), occupy=1, create_if_missing=False
|
1363
|
+
):
|
1364
|
+
result = await call_with_parameters(self.task.fn, parameters)
|
786
1365
|
else:
|
787
|
-
result = call_with_parameters(self.task.fn, parameters)
|
788
|
-
|
789
|
-
|
1366
|
+
result = await call_with_parameters(self.task.fn, parameters)
|
1367
|
+
await self.handle_success(result, transaction=transaction)
|
1368
|
+
return result
|
790
1369
|
|
791
1370
|
|
792
1371
|
def run_task_sync(
|
@@ -799,7 +1378,7 @@ def run_task_sync(
|
|
799
1378
|
dependencies: Optional[Dict[str, Set[TaskRunInput]]] = None,
|
800
1379
|
context: Optional[Dict[str, Any]] = None,
|
801
1380
|
) -> Union[R, State, None]:
|
802
|
-
engine =
|
1381
|
+
engine = SyncTaskRunEngine[P, R](
|
803
1382
|
task=task,
|
804
1383
|
parameters=parameters,
|
805
1384
|
task_run=task_run,
|
@@ -826,7 +1405,7 @@ async def run_task_async(
|
|
826
1405
|
dependencies: Optional[Dict[str, Set[TaskRunInput]]] = None,
|
827
1406
|
context: Optional[Dict[str, Any]] = None,
|
828
1407
|
) -> Union[R, State, None]:
|
829
|
-
engine =
|
1408
|
+
engine = AsyncTaskRunEngine[P, R](
|
830
1409
|
task=task,
|
831
1410
|
parameters=parameters,
|
832
1411
|
task_run=task_run,
|
@@ -834,13 +1413,13 @@ async def run_task_async(
|
|
834
1413
|
context=context,
|
835
1414
|
)
|
836
1415
|
|
837
|
-
with engine.start(task_run_id=task_run_id, dependencies=dependencies):
|
1416
|
+
async with engine.start(task_run_id=task_run_id, dependencies=dependencies):
|
838
1417
|
while engine.is_running():
|
839
1418
|
await engine.wait_until_ready()
|
840
|
-
with engine.run_context(), engine.transaction_context() as txn:
|
1419
|
+
async with engine.run_context(), engine.transaction_context() as txn:
|
841
1420
|
await engine.call_task_fn(txn)
|
842
1421
|
|
843
|
-
return engine.state if return_type == "state" else engine.result()
|
1422
|
+
return engine.state if return_type == "state" else await engine.result()
|
844
1423
|
|
845
1424
|
|
846
1425
|
def run_generator_task_sync(
|
@@ -856,7 +1435,7 @@ def run_generator_task_sync(
|
|
856
1435
|
if return_type != "result":
|
857
1436
|
raise ValueError("The return_type for a generator task must be 'result'")
|
858
1437
|
|
859
|
-
engine =
|
1438
|
+
engine = SyncTaskRunEngine[P, R](
|
860
1439
|
task=task,
|
861
1440
|
parameters=parameters,
|
862
1441
|
task_run=task_run,
|
@@ -910,7 +1489,7 @@ async def run_generator_task_async(
|
|
910
1489
|
) -> AsyncGenerator[R, None]:
|
911
1490
|
if return_type != "result":
|
912
1491
|
raise ValueError("The return_type for a generator task must be 'result'")
|
913
|
-
engine =
|
1492
|
+
engine = AsyncTaskRunEngine[P, R](
|
914
1493
|
task=task,
|
915
1494
|
parameters=parameters,
|
916
1495
|
task_run=task_run,
|
@@ -918,10 +1497,10 @@ async def run_generator_task_async(
|
|
918
1497
|
context=context,
|
919
1498
|
)
|
920
1499
|
|
921
|
-
with engine.start(task_run_id=task_run_id, dependencies=dependencies):
|
1500
|
+
async with engine.start(task_run_id=task_run_id, dependencies=dependencies):
|
922
1501
|
while engine.is_running():
|
923
1502
|
await engine.wait_until_ready()
|
924
|
-
with engine.run_context(), engine.transaction_context() as txn:
|
1503
|
+
async with engine.run_context(), engine.transaction_context() as txn:
|
925
1504
|
# TODO: generators should default to commit_mode=OFF
|
926
1505
|
# because they are dynamic by definition
|
927
1506
|
# for now we just prevent this branch explicitly
|
@@ -945,13 +1524,13 @@ async def run_generator_task_async(
|
|
945
1524
|
link_state_to_result(engine.state, gen_result)
|
946
1525
|
yield gen_result
|
947
1526
|
except (StopAsyncIteration, GeneratorExit) as exc:
|
948
|
-
engine.handle_success(None, transaction=txn)
|
1527
|
+
await engine.handle_success(None, transaction=txn)
|
949
1528
|
if isinstance(exc, GeneratorExit):
|
950
1529
|
gen.throw(exc)
|
951
1530
|
|
952
1531
|
# async generators can't return, but we can raise failures here
|
953
1532
|
if engine.state.is_failed():
|
954
|
-
engine.result()
|
1533
|
+
await engine.result()
|
955
1534
|
|
956
1535
|
|
957
1536
|
def run_task(
|