prefect-client 3.0.0rc13__py3-none-any.whl → 3.0.0rc15__py3-none-any.whl
This diff shows the changes between two package versions that were publicly released to a supported registry, as they appear in that registry. It is provided for informational purposes only.
- prefect/_internal/compatibility/deprecated.py +0 -53
- prefect/blocks/core.py +132 -4
- prefect/blocks/notifications.py +26 -3
- prefect/client/base.py +30 -24
- prefect/client/orchestration.py +121 -47
- prefect/client/utilities.py +4 -4
- prefect/concurrency/asyncio.py +48 -7
- prefect/concurrency/context.py +24 -0
- prefect/concurrency/services.py +24 -8
- prefect/concurrency/sync.py +30 -3
- prefect/context.py +85 -24
- prefect/events/clients.py +93 -60
- prefect/events/utilities.py +0 -2
- prefect/events/worker.py +9 -2
- prefect/flow_engine.py +6 -3
- prefect/flows.py +176 -12
- prefect/futures.py +84 -7
- prefect/profiles.toml +16 -2
- prefect/runner/runner.py +6 -1
- prefect/runner/storage.py +4 -0
- prefect/settings.py +108 -14
- prefect/task_engine.py +901 -285
- prefect/task_runs.py +24 -1
- prefect/task_worker.py +7 -1
- prefect/tasks.py +9 -5
- prefect/utilities/asyncutils.py +0 -6
- prefect/utilities/callables.py +5 -3
- prefect/utilities/engine.py +3 -0
- prefect/utilities/importtools.py +138 -58
- prefect/utilities/schema_tools/validation.py +30 -0
- prefect/utilities/services.py +32 -0
- {prefect_client-3.0.0rc13.dist-info → prefect_client-3.0.0rc15.dist-info}/METADATA +39 -39
- {prefect_client-3.0.0rc13.dist-info → prefect_client-3.0.0rc15.dist-info}/RECORD +36 -35
- {prefect_client-3.0.0rc13.dist-info → prefect_client-3.0.0rc15.dist-info}/WHEEL +1 -1
- {prefect_client-3.0.0rc13.dist-info → prefect_client-3.0.0rc15.dist-info}/LICENSE +0 -0
- {prefect_client-3.0.0rc13.dist-info → prefect_client-3.0.0rc15.dist-info}/top_level.txt +0 -0
prefect/task_engine.py
CHANGED
@@ -3,7 +3,7 @@ import logging
 import threading
 import time
 from asyncio import CancelledError
-from contextlib import ExitStack, contextmanager
+from contextlib import ExitStack, asynccontextmanager, contextmanager
 from dataclasses import dataclass, field
 from functools import wraps
 from textwrap import dedent
@@ -31,12 +31,16 @@ import pendulum
 from typing_extensions import ParamSpec
 
 from prefect import Task
-from prefect.client.orchestration import SyncPrefectClient
+from prefect.client.orchestration import PrefectClient, SyncPrefectClient, get_client
 from prefect.client.schemas import TaskRun
 from prefect.client.schemas.objects import State, TaskRunInput
+from prefect.concurrency.asyncio import concurrency as aconcurrency
+from prefect.concurrency.context import ConcurrencyContext
+from prefect.concurrency.sync import concurrency
 from prefect.context import (
-
+    AsyncClientContext,
     FlowRunContext,
+    SyncClientContext,
     TaskRunContext,
     hydrated_context,
 )
@@ -59,6 +63,7 @@ from prefect.settings import (
 )
 from prefect.states import (
     AwaitingRetry,
+    Completed,
     Failed,
     Paused,
     Pending,
@@ -77,6 +82,7 @@ from prefect.utilities.engine import (
     _get_hook_name,
     emit_task_run_state_change_event,
     link_state_to_result,
+    propose_state,
     propose_state_sync,
     resolve_to_final_result,
 )
@@ -86,47 +92,767 @@ from prefect.utilities.timeout import timeout, timeout_async
 P = ParamSpec("P")
 R = TypeVar("R")
 
+BACKOFF_MAX = 10
+
 
 class TaskRunTimeoutError(TimeoutError):
     """Raised when a task run exceeds its timeout."""
 
 
-@dataclass
-class TaskRunEngine(Generic[P, R]):
-    task: Union[Task[P, R], Task[P, Coroutine[Any, Any, R]]]
-    logger: logging.Logger = field(default_factory=lambda: get_logger("engine"))
-    parameters: Optional[Dict[str, Any]] = None
-    task_run: Optional[TaskRun] = None
-    retries: int = 0
-    wait_for: Optional[Iterable[PrefectFuture]] = None
-    context: Optional[Dict[str, Any]] = None
-    # holds the return value from the user code
-    _return_value: Union[R, Type[NotSet]] = NotSet
-    # holds the exception raised by the user code, if any
-    _raised: Union[Exception, Type[NotSet]] = NotSet
-    _initial_run_context: Optional[TaskRunContext] = None
-    _is_started: bool = False
-
-
-
+@dataclass
+class BaseTaskRunEngine(Generic[P, R]):
+    task: Union[Task[P, R], Task[P, Coroutine[Any, Any, R]]]
+    logger: logging.Logger = field(default_factory=lambda: get_logger("engine"))
+    parameters: Optional[Dict[str, Any]] = None
+    task_run: Optional[TaskRun] = None
+    retries: int = 0
+    wait_for: Optional[Iterable[PrefectFuture]] = None
+    context: Optional[Dict[str, Any]] = None
+    # holds the return value from the user code
+    _return_value: Union[R, Type[NotSet]] = NotSet
+    # holds the exception raised by the user code, if any
+    _raised: Union[Exception, Type[NotSet]] = NotSet
+    _initial_run_context: Optional[TaskRunContext] = None
+    _is_started: bool = False
+    _task_name_set: bool = False
+    _last_event: Optional[PrefectEvent] = None
+
+    def __post_init__(self):
+        if self.parameters is None:
+            self.parameters = {}
+
+    @property
+    def state(self) -> State:
+        if not self.task_run:
+            raise ValueError("Task run is not set")
+        return self.task_run.state
+
+    def is_cancelled(self) -> bool:
+        if (
+            self.context
+            and "cancel_event" in self.context
+            and isinstance(self.context["cancel_event"], threading.Event)
+        ):
+            return self.context["cancel_event"].is_set()
+        return False
+
+    def compute_transaction_key(self) -> Optional[str]:
+        key = None
+        if self.task.cache_policy:
+            flow_run_context = FlowRunContext.get()
+            task_run_context = TaskRunContext.get()
+
+            if flow_run_context:
+                parameters = flow_run_context.parameters
+            else:
+                parameters = None
+
+            key = self.task.cache_policy.compute_key(
+                task_ctx=task_run_context,
+                inputs=self.parameters,
+                flow_parameters=parameters,
+            )
+        elif self.task.result_storage_key is not None:
+            key = _format_user_supplied_storage_key(self.task.result_storage_key)
+        return key
+
+    def _resolve_parameters(self):
+        if not self.parameters:
+            return {}
+
+        resolved_parameters = {}
+        for parameter, value in self.parameters.items():
+            try:
+                resolved_parameters[parameter] = visit_collection(
+                    value,
+                    visit_fn=resolve_to_final_result,
+                    return_data=True,
+                    max_depth=-1,
+                    remove_annotations=True,
+                    context={},
+                )
+            except UpstreamTaskError:
+                raise
+            except Exception as exc:
+                raise PrefectException(
+                    f"Failed to resolve inputs in parameter {parameter!r}. If your"
+                    " parameter type is not supported, consider using the `quote`"
+                    " annotation to skip resolution of inputs."
+                ) from exc
+
+        self.parameters = resolved_parameters
+
+    def _wait_for_dependencies(self):
+        if not self.wait_for:
+            return
+
+        visit_collection(
+            self.wait_for,
+            visit_fn=resolve_to_final_result,
+            return_data=False,
+            max_depth=-1,
+            remove_annotations=True,
+            context={"current_task_run": self.task_run, "current_task": self.task},
+        )
+
+    def record_terminal_state_timing(self, state: State) -> None:
+        if PREFECT_EXPERIMENTAL_ENABLE_CLIENT_SIDE_TASK_ORCHESTRATION:
+            if self.task_run.start_time and not self.task_run.end_time:
+                self.task_run.end_time = state.timestamp
+
+                if self.task_run.state.is_running():
+                    self.task_run.total_run_time += (
+                        state.timestamp - self.task_run.state.timestamp
+                    )
+
+    def is_running(self) -> bool:
+        """Whether or not the engine is currently running a task."""
+        if (task_run := getattr(self, "task_run", None)) is None:
+            return False
+        return task_run.state.is_running() or task_run.state.is_scheduled()
+
+    def log_finished_message(self):
+        # If debugging, use the more complete `repr` than the usual `str` description
+        display_state = repr(self.state) if PREFECT_DEBUG_MODE else str(self.state)
+        level = logging.INFO if self.state.is_completed() else logging.ERROR
+        msg = f"Finished in state {display_state}"
+        if self.state.is_pending():
+            msg += (
+                "\nPlease wait for all submitted tasks to complete"
+                " before exiting your flow by calling `.wait()` on the "
+                "`PrefectFuture` returned from your `.submit()` calls."
+            )
+            msg += dedent(
+                """
+
+                Example:
+
+                from prefect import flow, task
+
+                @task
+                def say_hello(name):
+                    print f"Hello, {name}!"
+
+                @flow
+                def example_flow():
+                    future = say_hello.submit(name="Marvin)
+                    future.wait()
+
+                example_flow()
+                """
+            )
+        self.logger.log(
+            level=level,
+            msg=msg,
+        )
+
+    def handle_rollback(self, txn: Transaction) -> None:
+        assert self.task_run is not None
+
+        rolled_back_state = Completed(
+            name="RolledBack",
+            message="Task rolled back as part of transaction",
+        )
+
+        self._last_event = emit_task_run_state_change_event(
+            task_run=self.task_run,
+            initial_state=self.state,
+            validated_state=rolled_back_state,
+            follows=self._last_event,
+        )
+
+
+@dataclass
+class SyncTaskRunEngine(BaseTaskRunEngine[P, R]):
+    _client: Optional[SyncPrefectClient] = None
+
+    @property
+    def client(self) -> SyncPrefectClient:
+        if not self._is_started or self._client is None:
+            raise RuntimeError("Engine has not started.")
+        return self._client
+
+    def can_retry(self, exc: Exception) -> bool:
+        retry_condition: Optional[
+            Callable[[Task[P, Coroutine[Any, Any, R]], TaskRun, State], bool]
+        ] = self.task.retry_condition_fn
+        if not self.task_run:
+            raise ValueError("Task run is not set")
+        try:
+            self.logger.debug(
+                f"Running `retry_condition_fn` check {retry_condition!r} for task"
+                f" {self.task.name!r}"
+            )
+            state = Failed(
+                data=exc,
+                message=f"Task run encountered unexpected exception: {repr(exc)}",
+            )
+            if inspect.iscoroutinefunction(retry_condition):
+                should_retry = run_coro_as_sync(
+                    retry_condition(self.task, self.task_run, state)
+                )
+            elif inspect.isfunction(retry_condition):
+                should_retry = retry_condition(self.task, self.task_run, state)
+            else:
+                should_retry = not retry_condition
+            return should_retry
+        except Exception:
+            self.logger.error(
+                (
+                    "An error was encountered while running `retry_condition_fn` check"
+                    f" '{retry_condition!r}' for task {self.task.name!r}"
+                ),
+                exc_info=True,
+            )
+            return False
+
+    def call_hooks(self, state: Optional[State] = None):
+        if state is None:
+            state = self.state
+        task = self.task
+        task_run = self.task_run
+
+        if not task_run:
+            raise ValueError("Task run is not set")
+
+        if state.is_failed() and task.on_failure_hooks:
+            hooks = task.on_failure_hooks
+        elif state.is_completed() and task.on_completion_hooks:
+            hooks = task.on_completion_hooks
+        else:
+            hooks = None
+
+        for hook in hooks or []:
+            hook_name = _get_hook_name(hook)
+
+            try:
+                self.logger.info(
+                    f"Running hook {hook_name!r} in response to entering state"
+                    f" {state.name!r}"
+                )
+                result = hook(task, task_run, state)
+                if inspect.isawaitable(result):
+                    run_coro_as_sync(result)
+            except Exception:
+                self.logger.error(
+                    f"An error was encountered while running hook {hook_name!r}",
+                    exc_info=True,
+                )
+            else:
+                self.logger.info(f"Hook {hook_name!r} finished running successfully")
+
+    def begin_run(self):
+        try:
+            self._resolve_parameters()
+            self._wait_for_dependencies()
+        except UpstreamTaskError as upstream_exc:
+            state = self.set_state(
+                Pending(
+                    name="NotReady",
+                    message=str(upstream_exc),
+                ),
+                # if orchestrating a run already in a pending state, force orchestration to
+                # update the state name
+                force=self.state.is_pending(),
+            )
+            return
+
+        new_state = Running()
+
+        if PREFECT_EXPERIMENTAL_ENABLE_CLIENT_SIDE_TASK_ORCHESTRATION:
+            self.task_run.start_time = new_state.timestamp
+            self.task_run.run_count += 1
+
+            flow_run_context = FlowRunContext.get()
+            if flow_run_context:
+                # Carry forward any task run information from the flow run
+                flow_run = flow_run_context.flow_run
+                self.task_run.flow_run_run_count = flow_run.run_count
+
+        state = self.set_state(new_state)
+
+        # TODO: this is temporary until the API stops rejecting state transitions
+        # and the client / transaction store becomes the source of truth
+        # this is a bandaid caused by the API storing a Completed state with a bad
+        # result reference that no longer exists
+        if state.is_completed():
+            try:
+                state.result(retry_result_failure=False, _sync=True)
+            except Exception:
+                state = self.set_state(new_state, force=True)
+
+        backoff_count = 0
+
+        # TODO: Could this listen for state change events instead of polling?
+        while state.is_pending() or state.is_paused():
+            if backoff_count < BACKOFF_MAX:
+                backoff_count += 1
+            interval = clamped_poisson_interval(
+                average_interval=backoff_count, clamping_factor=0.3
+            )
+            time.sleep(interval)
+            state = self.set_state(new_state)
+
+    def set_state(self, state: State, force: bool = False) -> State:
+        last_state = self.state
+        if not self.task_run:
+            raise ValueError("Task run is not set")
+
+        if PREFECT_EXPERIMENTAL_ENABLE_CLIENT_SIDE_TASK_ORCHESTRATION:
+            self.task_run.state = new_state = state
+
+            # Ensure that the state_details are populated with the current run IDs
+            new_state.state_details.task_run_id = self.task_run.id
+            new_state.state_details.flow_run_id = self.task_run.flow_run_id
+
+            # Predictively update the de-normalized task_run.state_* attributes
+            self.task_run.state_id = new_state.id
+            self.task_run.state_type = new_state.type
+            self.task_run.state_name = new_state.name
+
+            if new_state.is_final():
+                if (
+                    isinstance(state.data, BaseResult)
+                    and state.data.has_cached_object()
+                ):
+                    # Avoid fetching the result unless it is cached, otherwise we defeat
+                    # the purpose of disabling `cache_result_in_memory`
+                    result = state.result(raise_on_failure=False, fetch=True)
+                    if inspect.isawaitable(result):
+                        result = run_coro_as_sync(result)
+                else:
+                    result = state.data
+
+                link_state_to_result(state, result)
+
+        else:
+            try:
+                new_state = propose_state_sync(
+                    self.client, state, task_run_id=self.task_run.id, force=force
+                )
+            except Pause as exc:
+                # We shouldn't get a pause signal without a state, but if this happens,
+                # just use a Paused state to assume an in-process pause.
+                new_state = exc.state if exc.state else Paused()
+                if new_state.state_details.pause_reschedule:
+                    # If we're being asked to pause and reschedule, we should exit the
+                    # task and expect to be resumed later.
+                    raise
+
+            # currently this is a hack to keep a reference to the state object
+            # that has an in-memory result attached to it; using the API state
+            # could result in losing that reference
+            self.task_run.state = new_state
+
+        # emit a state change event
+        self._last_event = emit_task_run_state_change_event(
+            task_run=self.task_run,
+            initial_state=last_state,
+            validated_state=self.task_run.state,
+            follows=self._last_event,
+        )
+
+        return new_state
+
+    def result(self, raise_on_failure: bool = True) -> "Union[R, State, None]":
+        if self._return_value is not NotSet:
+            # if the return value is a BaseResult, we need to fetch it
+            if isinstance(self._return_value, BaseResult):
+                _result = self._return_value.get()
+                if inspect.isawaitable(_result):
+                    _result = run_coro_as_sync(_result)
+                return _result
+
+            # otherwise, return the value as is
+            return self._return_value
+
+        if self._raised is not NotSet:
+            # if the task raised an exception, raise it
+            if raise_on_failure:
+                raise self._raised
+
+            # otherwise, return the exception
+            return self._raised
+
+    def handle_success(self, result: R, transaction: Transaction) -> R:
+        result_factory = getattr(TaskRunContext.get(), "result_factory", None)
+        if result_factory is None:
+            raise ValueError("Result factory is not set")
+
+        if self.task.cache_expiration is not None:
+            expiration = pendulum.now("utc") + self.task.cache_expiration
+        else:
+            expiration = None
+
+        terminal_state = run_coro_as_sync(
+            return_value_to_state(
+                result,
+                result_factory=result_factory,
+                key=transaction.key,
+                expiration=expiration,
+                # defer persistence to transaction commit
+                defer_persistence=True,
+            )
+        )
+        transaction.stage(
+            terminal_state.data,
+            on_rollback_hooks=[self.handle_rollback]
+            + [
+                _with_transaction_hook_logging(hook, "rollback", self.logger)
+                for hook in self.task.on_rollback_hooks
+            ],
+            on_commit_hooks=[
+                _with_transaction_hook_logging(hook, "commit", self.logger)
+                for hook in self.task.on_commit_hooks
+            ],
+        )
+        if transaction.is_committed():
+            terminal_state.name = "Cached"
+
+        self.record_terminal_state_timing(terminal_state)
+        self.set_state(terminal_state)
+        self._return_value = result
+        return result
+
+    def handle_retry(self, exc: Exception) -> bool:
+        """Handle any task run retries.
+
+        - If the task has retries left, and the retry condition is met, set the task to retrying and return True.
+        - If the task has a retry delay, place in AwaitingRetry state with a delayed scheduled time.
+        - If the task has no retries left, or the retry condition is not met, return False.
+        """
+        if self.retries < self.task.retries and self.can_retry(exc):
+            if self.task.retry_delay_seconds:
+                delay = (
+                    self.task.retry_delay_seconds[
+                        min(self.retries, len(self.task.retry_delay_seconds) - 1)
+                    ]  # repeat final delay value if attempts exceed specified delays
+                    if isinstance(self.task.retry_delay_seconds, Sequence)
+                    else self.task.retry_delay_seconds
+                )
+                new_state = AwaitingRetry(
+                    scheduled_time=pendulum.now("utc").add(seconds=delay)
+                )
+            else:
+                delay = None
+                new_state = Retrying()
+                if PREFECT_EXPERIMENTAL_ENABLE_CLIENT_SIDE_TASK_ORCHESTRATION:
+                    self.task_run.run_count += 1
+
+            self.logger.info(
+                "Task run failed with exception: %r - " "Retry %s/%s will start %s",
+                exc,
+                self.retries + 1,
+                self.task.retries,
+                str(delay) + " second(s) from now" if delay else "immediately",
+            )
+
+            self.set_state(new_state, force=True)
+            self.retries = self.retries + 1
+            return True
+        elif self.retries >= self.task.retries:
+            self.logger.error(
+                "Task run failed with exception: %r - Retries are exhausted",
+                exc,
+                exc_info=True,
+            )
+            return False
+
+        return False
+
+    def handle_exception(self, exc: Exception) -> None:
+        # If the task fails, and we have retries left, set the task to retrying.
+        if not self.handle_retry(exc):
+            # If the task has no retries left, or the retry condition is not met, set the task to failed.
+            context = TaskRunContext.get()
+            state = run_coro_as_sync(
+                exception_to_failed_state(
+                    exc,
+                    message="Task run encountered an exception",
+                    result_factory=getattr(context, "result_factory", None),
+                )
+            )
+            self.record_terminal_state_timing(state)
+            self.set_state(state)
+            self._raised = exc
+
+    def handle_timeout(self, exc: TimeoutError) -> None:
+        if not self.handle_retry(exc):
+            if isinstance(exc, TaskRunTimeoutError):
+                message = f"Task run exceeded timeout of {self.task.timeout_seconds} second(s)"
+            else:
+                message = f"Task run failed due to timeout: {exc!r}"
+            self.logger.error(message)
+            state = Failed(
+                data=exc,
+                message=message,
+                name="TimedOut",
+            )
+            self.set_state(state)
+            self._raised = exc
+
+    def handle_crash(self, exc: BaseException) -> None:
+        state = run_coro_as_sync(exception_to_crashed_state(exc))
+        self.logger.error(f"Crash detected! {state.message}")
+        self.logger.debug("Crash details:", exc_info=exc)
+        self.record_terminal_state_timing(state)
+        self.set_state(state, force=True)
+        self._raised = exc
+
+    @contextmanager
+    def setup_run_context(self, client: Optional[SyncPrefectClient] = None):
+        from prefect.utilities.engine import (
+            _resolve_custom_task_run_name,
+            should_log_prints,
+        )
+
+        if client is None:
+            client = self.client
+        if not self.task_run:
+            raise ValueError("Task run is not set")
+
+        if not PREFECT_EXPERIMENTAL_ENABLE_CLIENT_SIDE_TASK_ORCHESTRATION:
+            self.task_run = client.read_task_run(self.task_run.id)
+        with ExitStack() as stack:
+            if log_prints := should_log_prints(self.task):
+                stack.enter_context(patch_print())
+            stack.enter_context(
+                TaskRunContext(
+                    task=self.task,
+                    log_prints=log_prints,
+                    task_run=self.task_run,
+                    parameters=self.parameters,
+                    result_factory=run_coro_as_sync(ResultFactory.from_task(self.task)),  # type: ignore
+                    client=client,
+                )
+            )
+            stack.enter_context(ConcurrencyContext())
+
+            self.logger = task_run_logger(task_run=self.task_run, task=self.task)  # type: ignore
+
+            if not PREFECT_EXPERIMENTAL_ENABLE_CLIENT_SIDE_TASK_ORCHESTRATION:
+                # update the task run name if necessary
+                if not self._task_name_set and self.task.task_run_name:
+                    task_run_name = _resolve_custom_task_run_name(
+                        task=self.task, parameters=self.parameters
+                    )
+                    self.client.set_task_run_name(
+                        task_run_id=self.task_run.id, name=task_run_name
+                    )
+                    self.logger.extra["task_run_name"] = task_run_name
+                    self.logger.debug(
+                        f"Renamed task run {self.task_run.name!r} to {task_run_name!r}"
+                    )
+                    self.task_run.name = task_run_name
+                    self._task_name_set = True
+            yield
+
+    @contextmanager
+    def initialize_run(
+        self,
+        task_run_id: Optional[UUID] = None,
+        dependencies: Optional[Dict[str, Set[TaskRunInput]]] = None,
+    ) -> Generator["SyncTaskRunEngine", Any, Any]:
+        """
+        Enters a client context and creates a task run if needed.
+        """
+
+        with hydrated_context(self.context):
+            with SyncClientContext.get_or_create() as client_ctx:
+                self._client = client_ctx.client
+                self._is_started = True
+                try:
+                    if PREFECT_EXPERIMENTAL_ENABLE_CLIENT_SIDE_TASK_ORCHESTRATION:
+                        from prefect.utilities.engine import (
+                            _resolve_custom_task_run_name,
+                        )
+
+                        task_run_name = (
+                            _resolve_custom_task_run_name(
+                                task=self.task, parameters=self.parameters
+                            )
+                            if self.task.task_run_name
+                            else None
+                        )
+
+                        if self.task_run and task_run_name:
+                            self.task_run.name = task_run_name
+
+                        if not self.task_run:
+                            self.task_run = run_coro_as_sync(
+                                self.task.create_local_run(
+                                    id=task_run_id,
+                                    parameters=self.parameters,
+                                    flow_run_context=FlowRunContext.get(),
+                                    parent_task_run_context=TaskRunContext.get(),
+                                    wait_for=self.wait_for,
+                                    extra_task_inputs=dependencies,
+                                    task_run_name=task_run_name,
+                                )
+                            )
+                            # Emit an event to capture that the task run was in the `PENDING` state.
+                            self._last_event = emit_task_run_state_change_event(
+                                task_run=self.task_run,
+                                initial_state=None,
+                                validated_state=self.task_run.state,
+                            )
+                    else:
+                        if not self.task_run:
+                            self.task_run = run_coro_as_sync(
+                                self.task.create_run(
+                                    id=task_run_id,
+                                    parameters=self.parameters,
+                                    flow_run_context=FlowRunContext.get(),
+                                    parent_task_run_context=TaskRunContext.get(),
+                                    wait_for=self.wait_for,
+                                    extra_task_inputs=dependencies,
+                                )
+                            )
+                            # Emit an event to capture that the task run was in the `PENDING` state.
+                            self._last_event = emit_task_run_state_change_event(
+                                task_run=self.task_run,
+                                initial_state=None,
+                                validated_state=self.task_run.state,
+                            )
+
+                    with self.setup_run_context():
+                        # setup_run_context might update the task run name, so log creation here
+                        self.logger.info(
+                            f"Created task run {self.task_run.name!r} for task {self.task.name!r}"
+                        )
+                        yield self
+
+                except TerminationSignal as exc:
+                    # TerminationSignals are caught and handled as crashes
+                    self.handle_crash(exc)
+                    raise exc
+
+                except Exception:
+                    # regular exceptions are caught and re-raised to the user
+                    raise
+                except (Pause, Abort) as exc:
+                    # Do not capture internal signals as crashes
+                    if isinstance(exc, Abort):
+                        self.logger.error("Task run was aborted: %s", exc)
+                    raise
+                except GeneratorExit:
+                    # Do not capture generator exits as crashes
+                    raise
+                except BaseException as exc:
+                    # BaseExceptions are caught and handled as crashes
+                    self.handle_crash(exc)
+                    raise
+                finally:
+                    self.log_finished_message()
+                    self._is_started = False
+                    self._client = None
+
+    async def wait_until_ready(self):
+        """Waits until the scheduled time (if its the future), then enters Running."""
+        if scheduled_time := self.state.state_details.scheduled_time:
+            sleep_time = (scheduled_time - pendulum.now("utc")).total_seconds()
+            await anyio.sleep(sleep_time if sleep_time > 0 else 0)
+            self.set_state(
+                Retrying() if self.state.name == "AwaitingRetry" else Running(),
+                force=True,
+            )
+
+    # --------------------------
+    #
+    # The following methods compose the main task run loop
+    #
+    # --------------------------
+
+    @contextmanager
+    def start(
+        self,
+        task_run_id: Optional[UUID] = None,
+        dependencies: Optional[Dict[str, Set[TaskRunInput]]] = None,
+    ) -> Generator[None, None, None]:
+        with self.initialize_run(task_run_id=task_run_id, dependencies=dependencies):
+            self.begin_run()
+            try:
+                yield
+            finally:
+                self.call_hooks()
+
+    @contextmanager
+    def transaction_context(self) -> Generator[Transaction, None, None]:
+        result_factory = getattr(TaskRunContext.get(), "result_factory", None)
+
+        # refresh cache setting is now repurposes as overwrite transaction record
+        overwrite = (
+            self.task.refresh_cache
+            if self.task.refresh_cache is not None
+            else PREFECT_TASKS_REFRESH_CACHE.value()
+        )
+        with transaction(
+            key=self.compute_transaction_key(),
+            store=ResultFactoryStore(result_factory=result_factory),
+            overwrite=overwrite,
+            logger=self.logger,
+        ) as txn:
+            yield txn
+
+    @contextmanager
+    def run_context(self):
+        # reenter the run context to ensure it is up to date for every run
+        with self.setup_run_context():
+            try:
+                with timeout(
+                    seconds=self.task.timeout_seconds,
+                    timeout_exc_type=TaskRunTimeoutError,
+                ):
+                    self.logger.debug(
+                        f"Executing task {self.task.name!r} for task run {self.task_run.name!r}..."
+                    )
+                    if self.is_cancelled():
+                        raise CancelledError("Task run cancelled by the task runner")
+
+                    yield self
+            except TimeoutError as exc:
+                self.handle_timeout(exc)
+            except Exception as exc:
+                self.handle_exception(exc)
+
+    def call_task_fn(
+        self, transaction: Transaction
+    ) -> Union[R, Coroutine[Any, Any, R]]:
+        """
+        Convenience method to call the task function. Returns a coroutine if the
+        task is async.
+        """
+        parameters = self.parameters or {}
+        if transaction.is_committed():
+            result = transaction.read()
+        else:
+            if (
+                PREFECT_EXPERIMENTAL_ENABLE_CLIENT_SIDE_TASK_ORCHESTRATION.value()
+                and self.task.tags
+            ):
+                # Acquire a concurrency slot for each tag, but only if a limit
+                # matching the tag already exists.
+                with concurrency(
+                    list(self.task.tags), occupy=1, create_if_missing=False
+                ):
+                    result = call_with_parameters(self.task.fn, parameters)
+            else:
+                result = call_with_parameters(self.task.fn, parameters)
+        self.handle_success(result, transaction=transaction)
+        return result
+
 
-
-
-
+@dataclass
+class AsyncTaskRunEngine(BaseTaskRunEngine[P, R]):
+    _client: Optional[PrefectClient] = None
 
     @property
-    def client(self) -> SyncPrefectClient:
+    def client(self) -> PrefectClient:
         if not self._is_started or self._client is None:
             raise RuntimeError("Engine has not started.")
         return self._client
 
-
-    def state(self) -> State:
-        if not self.task_run:
-            raise ValueError("Task run is not set")
-        return self.task_run.state
-
-    def can_retry(self, exc: Exception) -> bool:
+    async def can_retry(self, exc: Exception) -> bool:
         retry_condition: Optional[
             Callable[[Task[P, Coroutine[Any, Any, R]], TaskRun, State], bool]
         ] = self.task.retry_condition_fn
@@ -142,14 +868,13 @@ class TaskRunEngine(Generic[P, R]):
             message=f"Task run encountered unexpected exception: {repr(exc)}",
             )
             if inspect.iscoroutinefunction(retry_condition):
-                should_retry = run_coro_as_sync(
-                    retry_condition(self.task, self.task_run, state)
-                )
+                should_retry = await retry_condition(self.task, self.task_run, state)
             elif inspect.isfunction(retry_condition):
                 should_retry = retry_condition(self.task, self.task_run, state)
             else:
                 should_retry = not retry_condition
             return should_retry
+
         except Exception:
             self.logger.error(
                 (
@@ -160,16 +885,7 @@ class TaskRunEngine(Generic[P, R]):
             )
             return False
 
-    def is_cancelled(self) -> bool:
-        if (
-            self.context
-            and "cancel_event" in self.context
-            and isinstance(self.context["cancel_event"], threading.Event)
-        ):
-            return self.context["cancel_event"].is_set()
-        return False
-
-    def call_hooks(self, state: Optional[State] = None):
+    async def call_hooks(self, state: Optional[State] = None):
         if state is None:
             state = self.state
         task = self.task
@@ -195,7 +911,7 @@ class TaskRunEngine(Generic[P, R]):
                 )
                 result = hook(task, task_run, state)
                 if inspect.isawaitable(result):
-
+                    await result
             except Exception:
                 self.logger.error(
                     f"An error was encountered while running hook {hook_name!r}",
@@ -204,71 +920,12 @@ class TaskRunEngine(Generic[P, R]):
             else:
                 self.logger.info(f"Hook {hook_name!r} finished running successfully")
 
-    def compute_transaction_key(self) -> Optional[str]:
-        key = None
-        if self.task.cache_policy:
-            flow_run_context = FlowRunContext.get()
-            task_run_context = TaskRunContext.get()
-
-            if flow_run_context:
-                parameters = flow_run_context.parameters
-            else:
-                parameters = None
-
-            key = self.task.cache_policy.compute_key(
-                task_ctx=task_run_context,
-                inputs=self.parameters,
-                flow_parameters=parameters,
-            )
-        elif self.task.result_storage_key is not None:
-            key = _format_user_supplied_storage_key(self.task.result_storage_key)
-        return key
-
-    def _resolve_parameters(self):
-        if not self.parameters:
-            return {}
-
-        resolved_parameters = {}
-        for parameter, value in self.parameters.items():
-            try:
-                resolved_parameters[parameter] = visit_collection(
-                    value,
-                    visit_fn=resolve_to_final_result,
-                    return_data=True,
-                    max_depth=-1,
-                    remove_annotations=True,
-                    context={},
-                )
-            except UpstreamTaskError:
-                raise
-            except Exception as exc:
-                raise PrefectException(
-                    f"Failed to resolve inputs in parameter {parameter!r}. If your"
-                    " parameter type is not supported, consider using the `quote`"
-                    " annotation to skip resolution of inputs."
-                ) from exc
-
-        self.parameters = resolved_parameters
-
-    def _wait_for_dependencies(self):
-        if not self.wait_for:
-            return
-
-        visit_collection(
-            self.wait_for,
-            visit_fn=resolve_to_final_result,
-            return_data=False,
-            max_depth=-1,
-            remove_annotations=True,
-            context={"current_task_run": self.task_run, "current_task": self.task},
-        )
-
-    def begin_run(self):
+    async def begin_run(self):
         try:
             self._resolve_parameters()
             self._wait_for_dependencies()
         except UpstreamTaskError as upstream_exc:
-            state = self.set_state(
+            state = await self.set_state(
                 Pending(
                     name="NotReady",
                     message=str(upstream_exc),
@@ -291,7 +948,7 @@ class TaskRunEngine(Generic[P, R]):
                 flow_run = flow_run_context.flow_run
                 self.task_run.flow_run_run_count = flow_run.run_count
 
-        state = self.set_state(new_state)
+        state = await self.set_state(new_state)
 
         # TODO: this is temporary until the API stops rejecting state transitions
         # and the client / transaction store becomes the source of truth
@@ -299,11 +956,10 @@ class TaskRunEngine(Generic[P, R]):
         # result reference that no longer exists
         if state.is_completed():
             try:
-                state.result(retry_result_failure=False, _sync=True)
+                await state.result(retry_result_failure=False)
             except Exception:
-                state = self.set_state(new_state, force=True)
+                state = await self.set_state(new_state, force=True)
 
-        BACKOFF_MAX = 10
         backoff_count = 0
 
         # TODO: Could this listen for state change events instead of polling?
@@ -313,10 +969,10 @@ class TaskRunEngine(Generic[P, R]):
             interval = clamped_poisson_interval(
                 average_interval=backoff_count, clamping_factor=0.3
             )
-
-            state = self.set_state(new_state)
+            await anyio.sleep(interval)
+            state = await self.set_state(new_state)
 
-    def set_state(self, state: State, force: bool = False) -> State:
+    async def set_state(self, state: State, force: bool = False) -> State:
         last_state = self.state
         if not self.task_run:
             raise ValueError("Task run is not set")
@@ -332,9 +988,23 @@ class TaskRunEngine(Generic[P, R]):
             self.task_run.state_id = new_state.id
             self.task_run.state_type = new_state.type
             self.task_run.state_name = new_state.name
+
+            if new_state.is_final():
+                if (
+                    isinstance(new_state.data, BaseResult)
+                    and new_state.data.has_cached_object()
+                ):
+                    # Avoid fetching the result unless it is cached, otherwise we defeat
+                    # the purpose of disabling `cache_result_in_memory`
+                    result = await new_state.result(raise_on_failure=False, fetch=True)
+                else:
+                    result = new_state.data
+
+                link_state_to_result(new_state, result)
+
         else:
             try:
-                new_state = propose_state_sync(
+                new_state = await propose_state(
                     self.client, state, task_run_id=self.task_run.id, force=force
                 )
             except Pause as exc:
@@ -361,14 +1031,11 @@ class TaskRunEngine(Generic[P, R]):
 
         return new_state
 
-    def result(self, raise_on_failure: bool = True) -> "Union[R, State, None]":
+    async def result(self, raise_on_failure: bool = True) -> "Union[R, State, None]":
         if self._return_value is not NotSet:
             # if the return value is a BaseResult, we need to fetch it
             if isinstance(self._return_value, BaseResult):
-
-                if inspect.isawaitable(_result):
-                    _result = run_coro_as_sync(_result)
-                return _result
+                return await self._return_value.get()
 
             # otherwise, return the value as is
             return self._return_value
@@ -381,7 +1048,7 @@ class TaskRunEngine(Generic[P, R]):
             # otherwise, return the exception
             return self._raised
 
-    def handle_success(self, result: R, transaction: Transaction) -> R:
+    async def handle_success(self, result: R, transaction: Transaction) -> R:
         result_factory = getattr(TaskRunContext.get(), "result_factory", None)
         if result_factory is None:
             raise ValueError("Result factory is not set")
@@ -391,19 +1058,18 @@ class TaskRunEngine(Generic[P, R]):
         else:
             expiration = None
 
-        terminal_state = run_coro_as_sync(
-
-
-
-
-
-
-                defer_persistence=True,
-            )
+        terminal_state = await return_value_to_state(
+            result,
+            result_factory=result_factory,
+            key=transaction.key,
+            expiration=expiration,
+            # defer persistence to transaction commit
+            defer_persistence=True,
         )
         transaction.stage(
             terminal_state.data,
-            on_rollback_hooks=[
+            on_rollback_hooks=[self.handle_rollback]
+            + [
                 _with_transaction_hook_logging(hook, "rollback", self.logger)
                 for hook in self.task.on_rollback_hooks
             ],
@@ -416,18 +1082,18 @@ class TaskRunEngine(Generic[P, R]):
             terminal_state.name = "Cached"
 
         self.record_terminal_state_timing(terminal_state)
-        self.set_state(terminal_state)
+        await self.set_state(terminal_state)
         self._return_value = result
         return result
 
-    def handle_retry(self, exc: Exception) -> bool:
+    async def handle_retry(self, exc: Exception) -> bool:
         """Handle any task run retries.
 
         - If the task has retries left, and the retry condition is met, set the task to retrying and return True.
-
+        - If the task has a retry delay, place in AwaitingRetry state with a delayed scheduled time.
         - If the task has no retries left, or the retry condition is not met, return False.
         """
-        if self.retries < self.task.retries and self.can_retry(exc):
+        if self.retries < self.task.retries and await self.can_retry(exc):
             if self.task.retry_delay_seconds:
                 delay = (
                     self.task.retry_delay_seconds[
@@ -453,7 +1119,7 @@ class TaskRunEngine(Generic[P, R]):
                 str(delay) + " second(s) from now" if delay else "immediately",
             )
 
-            self.set_state(new_state, force=True)
+            await self.set_state(new_state, force=True)
             self.retries = self.retries + 1
             return True
         elif self.retries >= self.task.retries:
@@ -466,24 +1132,22 @@ class TaskRunEngine(Generic[P, R]):
 
         return False
 
-    def handle_exception(self, exc: Exception) -> None:
+    async def handle_exception(self, exc: Exception) -> None:
         # If the task fails, and we have retries left, set the task to retrying.
-        if not self.handle_retry(exc):
+        if not await self.handle_retry(exc):
             # If the task has no retries left, or the retry condition is not met, set the task to failed.
             context = TaskRunContext.get()
-            state = run_coro_as_sync(
-
-
-
-                    result_factory=getattr(context, "result_factory", None),
-                )
+            state = await exception_to_failed_state(
+                exc,
+                message="Task run encountered an exception",
+                result_factory=getattr(context, "result_factory", None),
            )
             self.record_terminal_state_timing(state)
-            self.set_state(state)
+            await self.set_state(state)
             self._raised = exc
 
-    def handle_timeout(self, exc: TimeoutError) -> None:
-        if not self.handle_retry(exc):
+    async def handle_timeout(self, exc: TimeoutError) -> None:
+        if not await self.handle_retry(exc):
             if isinstance(exc, TaskRunTimeoutError):
                 message = f"Task run exceeded timeout of {self.task.timeout_seconds} second(s)"
             else:
@@ -494,29 +1158,19 @@ class TaskRunEngine(Generic[P, R]):
                 message=message,
                 name="TimedOut",
             )
-            self.set_state(state)
+            await self.set_state(state)
             self._raised = exc
 
-    def handle_crash(self, exc: BaseException) -> None:
-        state = run_coro_as_sync(exception_to_crashed_state(exc))
+    async def handle_crash(self, exc: BaseException) -> None:
+        state = await exception_to_crashed_state(exc)
         self.logger.error(f"Crash detected! {state.message}")
         self.logger.debug("Crash details:", exc_info=exc)
         self.record_terminal_state_timing(state)
-        self.set_state(state, force=True)
+        await self.set_state(state, force=True)
         self._raised = exc
 
-
-
-            if self.task_run.start_time and not self.task_run.end_time:
-                self.task_run.end_time = state.timestamp
-
-            if self.task_run.state.is_running():
-                self.task_run.total_run_time += (
-                    state.timestamp - self.task_run.state.timestamp
-                )
-
-    @contextmanager
-    def setup_run_context(self, client: Optional[SyncPrefectClient] = None):
+    @asynccontextmanager
+    async def setup_run_context(self, client: Optional[PrefectClient] = None):
         from prefect.utilities.engine import (
             _resolve_custom_task_run_name,
             should_log_prints,
@@ -528,7 +1182,7 @@ class TaskRunEngine(Generic[P, R]):
             raise ValueError("Task run is not set")
 
         if not PREFECT_EXPERIMENTAL_ENABLE_CLIENT_SIDE_TASK_ORCHESTRATION:
-            self.task_run = client.read_task_run(self.task_run.id)
+            self.task_run = await client.read_task_run(self.task_run.id)
         with ExitStack() as stack:
             if log_prints := should_log_prints(self.task):
                 stack.enter_context(patch_print())
@@ -538,10 +1192,11 @@ class TaskRunEngine(Generic[P, R]):
                     log_prints=log_prints,
                     task_run=self.task_run,
                     parameters=self.parameters,
-                    result_factory=run_coro_as_sync(ResultFactory.from_task(self.task)),  # type: ignore
+                    result_factory=await ResultFactory.from_task(self.task),  # type: ignore
                    client=client,
                )
            )
+            stack.enter_context(ConcurrencyContext())
 
             self.logger = task_run_logger(task_run=self.task_run, task=self.task)  # type: ignore
 
@@ -551,7 +1206,7 @@ class TaskRunEngine(Generic[P, R]):
                     task_run_name = _resolve_custom_task_run_name(
                         task=self.task, parameters=self.parameters
                     )
-                    self.client.set_task_run_name(
+                    await self.client.set_task_run_name(
                         task_run_id=self.task_run.id, name=task_run_name
                     )
                     self.logger.extra["task_run_name"] = task_run_name
@@ -562,19 +1217,19 @@ class TaskRunEngine(Generic[P, R]):
                     self._task_name_set = True
             yield
 
-    @contextmanager
-    def initialize_run(
+    @asynccontextmanager
+    async def initialize_run(
         self,
         task_run_id: Optional[UUID] = None,
         dependencies: Optional[Dict[str, Set[TaskRunInput]]] = None,
-    ) ->
+    ) -> AsyncGenerator["AsyncTaskRunEngine", Any]:
         """
         Enters a client context and creates a task run if needed.
         """
 
         with hydrated_context(self.context):
-            with
-                self._client =
+            async with AsyncClientContext.get_or_create():
+                self._client = get_client()
                 self._is_started = True
                 try:
                     if PREFECT_EXPERIMENTAL_ENABLE_CLIENT_SIDE_TASK_ORCHESTRATION:
@@ -594,37 +1249,39 @@ class TaskRunEngine(Generic[P, R]):
                             self.task_run.name = task_run_name
 
                         if not self.task_run:
-                            self.task_run = run_coro_as_sync(
-
-
-
-
-
-
-
-
-
+                            self.task_run = await self.task.create_local_run(
+                                id=task_run_id,
+                                parameters=self.parameters,
+                                flow_run_context=FlowRunContext.get(),
+                                parent_task_run_context=TaskRunContext.get(),
+                                wait_for=self.wait_for,
+                                extra_task_inputs=dependencies,
+                                task_run_name=task_run_name,
+                            )
+                            # Emit an event to capture that the task run was in the `PENDING` state.
+                            self._last_event = emit_task_run_state_change_event(
+                                task_run=self.task_run,
+                                initial_state=None,
+                                validated_state=self.task_run.state,
                            )
                     else:
                         if not self.task_run:
-                            self.task_run = run_coro_as_sync(
-
-
-
-
-
-
-
-
+                            self.task_run = await self.task.create_run(
+                                id=task_run_id,
+                                parameters=self.parameters,
+                                flow_run_context=FlowRunContext.get(),
+                                parent_task_run_context=TaskRunContext.get(),
+                                wait_for=self.wait_for,
+                                extra_task_inputs=dependencies,
+                            )
+                            # Emit an event to capture that the task run was in the `PENDING` state.
+                            self._last_event = emit_task_run_state_change_event(
+                                task_run=self.task_run,
+                                initial_state=None,
+                                validated_state=self.task_run.state,
                            )
-                    # Emit an event to capture that the task run was in the `PENDING` state.
-                    self._last_event = emit_task_run_state_change_event(
-                        task_run=self.task_run,
-                        initial_state=None,
-                        validated_state=self.task_run.state,
-                    )
 
-                    with self.setup_run_context():
+                    async with self.setup_run_context():
                         # setup_run_context might update the task run name, so log creation here
                         self.logger.info(
                             f"Created task run {self.task_run.name!r} for task {self.task.name!r}"
@@ -633,7 +1290,7 @@ class TaskRunEngine(Generic[P, R]):
 
                 except TerminationSignal as exc:
                     # TerminationSignals are caught and handled as crashes
-                    self.handle_crash(exc)
+                    await self.handle_crash(exc)
                     raise exc
 
                 except Exception:
@@ -649,60 +1306,19 @@ class TaskRunEngine(Generic[P, R]):
                     raise
                 except BaseException as exc:
                     # BaseExceptions are caught and handled as crashes
-                    self.handle_crash(exc)
+                    await self.handle_crash(exc)
                     raise
                 finally:
-
-                    display_state = (
-                        repr(self.state) if PREFECT_DEBUG_MODE else str(self.state)
-                    )
-                    level = logging.INFO if self.state.is_completed() else logging.ERROR
-                    msg = f"Finished in state {display_state}"
-                    if self.state.is_pending():
-                        msg += (
-                            "\nPlease wait for all submitted tasks to complete"
-                            " before exiting your flow by calling `.wait()` on the "
-                            "`PrefectFuture` returned from your `.submit()` calls."
-                        )
-                        msg += dedent(
-                            """
-
-                            Example:
-
-                            from prefect import flow, task
-
-                            @task
-                            def say_hello(name):
-                                print f"Hello, {name}!"
-
-                            @flow
-                            def example_flow():
-                                future = say_hello.submit(name="Marvin)
-                                future.wait()
-
-                            example_flow()
-                            """
-                        )
-                    self.logger.log(
-                        level=level,
-                        msg=msg,
-                    )
-
+                    self.log_finished_message()
                     self._is_started = False
                     self._client = None
 
-    def is_running(self) -> bool:
-        """Whether or not the engine is currently running a task."""
-        if (task_run := getattr(self, "task_run", None)) is None:
-            return False
-        return task_run.state.is_running() or task_run.state.is_scheduled()
-
     async def wait_until_ready(self):
         """Waits until the scheduled time (if its the future), then enters Running."""
         if scheduled_time := self.state.state_details.scheduled_time:
             sleep_time = (scheduled_time - pendulum.now("utc")).total_seconds()
             await anyio.sleep(sleep_time if sleep_time > 0 else 0)
-            self.set_state(
+            await self.set_state(
                 Retrying() if self.state.name == "AwaitingRetry" else Running(),
                 force=True,
             )
@@ -713,21 +1329,23 @@ class TaskRunEngine(Generic[P, R]):
     #
     # --------------------------
 
-    @contextmanager
-    def start(
+    @asynccontextmanager
+    async def start(
         self,
         task_run_id: Optional[UUID] = None,
         dependencies: Optional[Dict[str, Set[TaskRunInput]]] = None,
-    ) -> Generator[None, None, None]:
-        with self.initialize_run(task_run_id=task_run_id, dependencies=dependencies):
-
+    ) -> AsyncGenerator[None, None]:
+        async with self.initialize_run(
+            task_run_id=task_run_id, dependencies=dependencies
+        ):
+            await self.begin_run()
             try:
                 yield
             finally:
-                self.call_hooks()
+                await self.call_hooks()
 
-    @contextmanager
-    def transaction_context(self) -> Generator[Transaction, None, None]:
+    @asynccontextmanager
+    async def transaction_context(self) -> AsyncGenerator[Transaction, None]:
         result_factory = getattr(TaskRunContext.get(), "result_factory", None)
 
         # refresh cache setting is now repurposes as overwrite transaction record
@@ -744,13 +1362,12 @@ class TaskRunEngine(Generic[P, R]):
         ) as txn:
             yield txn
 
-    @contextmanager
-    def run_context(self):
-        timeout_context = timeout_async if self.task.isasync else timeout
+    @asynccontextmanager
+    async def run_context(self):
         # reenter the run context to ensure it is up to date for every run
-        with self.setup_run_context():
+        async with self.setup_run_context():
             try:
-                with timeout_context(
+                with timeout_async(
                     seconds=self.task.timeout_seconds,
                     timeout_exc_type=TaskRunTimeoutError,
                 ):
@@ -762,11 +1379,11 @@ class TaskRunEngine(Generic[P, R]):
 
                     yield self
             except TimeoutError as exc:
-                self.handle_timeout(exc)
+                await self.handle_timeout(exc)
             except Exception as exc:
-                self.handle_exception(exc)
+                await self.handle_exception(exc)
 
-    def call_task_fn(
+    async def call_task_fn(
         self, transaction: Transaction
     ) -> Union[R, Coroutine[Any, Any, R]]:
         """
@@ -774,24 +1391,23 @@ class TaskRunEngine(Generic[P, R]):
         task is async.
         """
         parameters = self.parameters or {}
-        if self.task.isasync:
-
-            async def _call_task_fn():
-                if transaction.is_committed():
-                    result = transaction.read()
-                else:
-                    result = await call_with_parameters(self.task.fn, parameters)
-                self.handle_success(result, transaction=transaction)
-                return result
-
-            return _call_task_fn()
+        if transaction.is_committed():
+            result = transaction.read()
         else:
-            if transaction.is_committed():
-
+            if (
+                PREFECT_EXPERIMENTAL_ENABLE_CLIENT_SIDE_TASK_ORCHESTRATION.value()
+                and self.task.tags
+            ):
+                # Acquire a concurrency slot for each tag, but only if a limit
+                # matching the tag already exists.
+                async with aconcurrency(
+                    list(self.task.tags), occupy=1, create_if_missing=False
+                ):
+                    result = await call_with_parameters(self.task.fn, parameters)
             else:
-                result = call_with_parameters(self.task.fn, parameters)
-
-
+                result = await call_with_parameters(self.task.fn, parameters)
+        await self.handle_success(result, transaction=transaction)
+        return result
 
 
 def run_task_sync(
@@ -804,7 +1420,7 @@ def run_task_sync(
     dependencies: Optional[Dict[str, Set[TaskRunInput]]] = None,
     context: Optional[Dict[str, Any]] = None,
 ) -> Union[R, State, None]:
-    engine = TaskRunEngine[P, R](
+    engine = SyncTaskRunEngine[P, R](
         task=task,
         parameters=parameters,
         task_run=task_run,
@@ -831,7 +1447,7 @@ async def run_task_async(
     dependencies: Optional[Dict[str, Set[TaskRunInput]]] = None,
     context: Optional[Dict[str, Any]] = None,
 ) -> Union[R, State, None]:
-    engine = TaskRunEngine[P, R](
+    engine = AsyncTaskRunEngine[P, R](
         task=task,
         parameters=parameters,
        task_run=task_run,
@@ -839,13 +1455,13 @@ async def run_task_async(
         context=context,
     )
 
-    with engine.start(task_run_id=task_run_id, dependencies=dependencies):
+    async with engine.start(task_run_id=task_run_id, dependencies=dependencies):
         while engine.is_running():
             await engine.wait_until_ready()
-            with engine.run_context(), engine.transaction_context() as txn:
+            async with engine.run_context(), engine.transaction_context() as txn:
                 await engine.call_task_fn(txn)
 
-    return engine.state if return_type == "state" else engine.result()
+    return engine.state if return_type == "state" else await engine.result()
 
 
 def run_generator_task_sync(
@@ -861,7 +1477,7 @@ def run_generator_task_sync(
     if return_type != "result":
         raise ValueError("The return_type for a generator task must be 'result'")
 
-    engine = TaskRunEngine[P, R](
+    engine = SyncTaskRunEngine[P, R](
         task=task,
         parameters=parameters,
         task_run=task_run,
@@ -915,7 +1531,7 @@ async def run_generator_task_async(
 ) -> AsyncGenerator[R, None]:
     if return_type != "result":
         raise ValueError("The return_type for a generator task must be 'result'")
-    engine = TaskRunEngine[P, R](
+    engine = AsyncTaskRunEngine[P, R](
         task=task,
         parameters=parameters,
         task_run=task_run,
@@ -923,10 +1539,10 @@ async def run_generator_task_async(
         context=context,
     )
 
-    with engine.start(task_run_id=task_run_id, dependencies=dependencies):
+    async with engine.start(task_run_id=task_run_id, dependencies=dependencies):
         while engine.is_running():
             await engine.wait_until_ready()
-            with engine.run_context(), engine.transaction_context() as txn:
+            async with engine.run_context(), engine.transaction_context() as txn:
                 # TODO: generators should default to commit_mode=OFF
                 # because they are dynamic by definition
                 # for now we just prevent this branch explicitly
@@ -950,13 +1566,13 @@ async def run_generator_task_async(
                         link_state_to_result(engine.state, gen_result)
                         yield gen_result
                 except (StopAsyncIteration, GeneratorExit) as exc:
-                    engine.handle_success(None, transaction=txn)
+                    await engine.handle_success(None, transaction=txn)
                     if isinstance(exc, GeneratorExit):
                         gen.throw(exc)
 
    # async generators can't return, but we can raise failures here
    if engine.state.is_failed():
-        engine.result()
+        await engine.result()
 
 
 def run_task(
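
Note: the new `call_task_fn` paths above only acquire a tag-based concurrency slot when the experimental client-side task orchestration setting is enabled and a limit matching the tag already exists (`create_if_missing=False`). A minimal sketch of how this surfaces in user code, assuming that setting is turned on and a concurrency limit named "database" has already been created (the tag, task, and flow names below are illustrative, not part of this diff):

    from prefect import flow, task

    @task(tags=["database"])  # assumes a concurrency limit named "database" already exists
    def query_db():
        return 42

    @flow
    def example_flow():
        # while query_db runs, it occupies one slot on the "database" limit;
        # without a matching limit, the task simply runs without acquiring anything
        return query_db()

    if __name__ == "__main__":
        example_flow()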
|