prefect-client 3.0.0rc1__py3-none-any.whl → 3.0.0rc2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
prefect/settings.py CHANGED
@@ -1208,6 +1208,9 @@ PREFECT_API_SERVICES_FOREMAN_WORK_QUEUE_LAST_POLLED_TIMEOUT_SECONDS = Setting(
1208
1208
  """The number of seconds before a work queue is marked as not ready if it has not been
1209
1209
  polled."""
1210
1210
 
1211
+ PREFECT_API_LOG_RETRYABLE_ERRORS = Setting(bool, default=False)
1212
+ """If `True`, log retryable errors in the API and it's services."""
1213
+
1211
1214
 
1212
1215
  PREFECT_API_DEFAULT_LIMIT = Setting(
1213
1216
  int,
@@ -1527,9 +1530,9 @@ PREFECT_TASK_SCHEDULING_PENDING_TASK_TIMEOUT = Setting(
1527
1530
  default=timedelta(seconds=30),
1528
1531
  )
1529
1532
  """
1530
- How long before a PENDING task are made available to another task server. In practice,
1531
- a task server should move a task from PENDING to RUNNING very quickly, so runs stuck in
1532
- PENDING for a while is a sign that the task server may have crashed.
1533
+ How long before a PENDING task are made available to another task worker. In practice,
1534
+ a task worker should move a task from PENDING to RUNNING very quickly, so runs stuck in
1535
+ PENDING for a while is a sign that the task worker may have crashed.
1533
1536
  """
1534
1537
 
1535
1538
  PREFECT_EXPERIMENTAL_ENABLE_EXTRA_RUNNER_ENDPOINTS = Setting(bool, default=False)
prefect/states.py CHANGED
@@ -204,7 +204,9 @@ async def exception_to_failed_state(
204
204
  return state
205
205
 
206
206
 
207
- async def return_value_to_state(retval: R, result_factory: ResultFactory) -> State[R]:
207
+ async def return_value_to_state(
208
+ retval: R, result_factory: ResultFactory, key: str = None
209
+ ) -> State[R]:
208
210
  """
209
211
  Given a return value from a user's function, create a `State` the run should
210
212
  be placed in.
@@ -236,7 +238,7 @@ async def return_value_to_state(retval: R, result_factory: ResultFactory) -> Sta
236
238
  # Unless the user has already constructed a result explicitly, use the factory
237
239
  # to update the data to the correct type
238
240
  if not isinstance(state.data, BaseResult):
239
- state.data = await result_factory.create_result(state.data)
241
+ state.data = await result_factory.create_result(state.data, key=key)
240
242
 
241
243
  return state
242
244
 
@@ -276,7 +278,7 @@ async def return_value_to_state(retval: R, result_factory: ResultFactory) -> Sta
276
278
  return State(
277
279
  type=new_state_type,
278
280
  message=message,
279
- data=await result_factory.create_result(retval),
281
+ data=await result_factory.create_result(retval, key=key),
280
282
  )
281
283
 
282
284
  # Generators aren't portable, implicitly convert them to a list.
@@ -289,7 +291,7 @@ async def return_value_to_state(retval: R, result_factory: ResultFactory) -> Sta
289
291
  if isinstance(data, BaseResult):
290
292
  return Completed(data=data)
291
293
  else:
292
- return Completed(data=await result_factory.create_result(data))
294
+ return Completed(data=await result_factory.create_result(data, key=key))
293
295
 
294
296
 
295
297
  @sync_compatible
prefect/task_engine.py CHANGED
@@ -13,12 +13,14 @@ from typing import (
13
13
  Iterable,
14
14
  Literal,
15
15
  Optional,
16
+ Sequence,
16
17
  Set,
17
18
  TypeVar,
18
19
  Union,
19
20
  )
20
21
  from uuid import UUID
21
22
 
23
+ import anyio
22
24
  import pendulum
23
25
  from typing_extensions import ParamSpec
24
26
 
@@ -50,19 +52,19 @@ from prefect.settings import (
50
52
  PREFECT_TASKS_REFRESH_CACHE,
51
53
  )
52
54
  from prefect.states import (
55
+ AwaitingRetry,
53
56
  Failed,
54
57
  Paused,
55
58
  Pending,
56
59
  Retrying,
57
60
  Running,
58
- StateDetails,
59
61
  exception_to_crashed_state,
60
62
  exception_to_failed_state,
61
63
  return_value_to_state,
62
64
  )
63
65
  from prefect.transactions import Transaction, transaction
64
66
  from prefect.utilities.asyncutils import run_coro_as_sync
65
- from prefect.utilities.callables import parameters_to_args_kwargs
67
+ from prefect.utilities.callables import call_with_parameters
66
68
  from prefect.utilities.collections import visit_collection
67
69
  from prefect.utilities.engine import (
68
70
  _get_hook_name,
@@ -133,101 +135,54 @@ class TaskRunEngine(Generic[P, R]):
133
135
  )
134
136
  return False
135
137
 
136
- def get_hooks(self, state: State, as_async: bool = False) -> Iterable[Callable]:
138
+ def call_hooks(self, state: State = None) -> Iterable[Callable]:
139
+ if state is None:
140
+ state = self.state
137
141
  task = self.task
138
142
  task_run = self.task_run
139
143
 
140
144
  if not task_run:
141
145
  raise ValueError("Task run is not set")
142
146
 
143
- hooks = None
144
147
  if state.is_failed() and task.on_failure_hooks:
145
148
  hooks = task.on_failure_hooks
146
149
  elif state.is_completed() and task.on_completion_hooks:
147
150
  hooks = task.on_completion_hooks
151
+ else:
152
+ hooks = None
148
153
 
149
154
  for hook in hooks or []:
150
155
  hook_name = _get_hook_name(hook)
151
156
 
152
- @contextmanager
153
- def hook_context():
154
- try:
155
- self.logger.info(
156
- f"Running hook {hook_name!r} in response to entering state"
157
- f" {state.name!r}"
158
- )
159
- yield
160
- except Exception:
161
- self.logger.error(
162
- f"An error was encountered while running hook {hook_name!r}",
163
- exc_info=True,
164
- )
165
- else:
166
- self.logger.info(
167
- f"Hook {hook_name!r} finished running successfully"
168
- )
169
-
170
- if as_async:
171
-
172
- async def _hook_fn():
173
- with hook_context():
174
- result = hook(task, task_run, state)
175
- if inspect.isawaitable(result):
176
- await result
177
-
157
+ try:
158
+ self.logger.info(
159
+ f"Running hook {hook_name!r} in response to entering state"
160
+ f" {state.name!r}"
161
+ )
162
+ result = hook(task, task_run, state)
163
+ if inspect.isawaitable(result):
164
+ run_coro_as_sync(result)
165
+ except Exception:
166
+ self.logger.error(
167
+ f"An error was encountered while running hook {hook_name!r}",
168
+ exc_info=True,
169
+ )
178
170
  else:
179
-
180
- def _hook_fn():
181
- with hook_context():
182
- result = hook(task, task_run, state)
183
- if inspect.isawaitable(result):
184
- run_coro_as_sync(result)
185
-
186
- yield _hook_fn
171
+ self.logger.info(f"Hook {hook_name!r} finished running successfully")
187
172
 
188
173
  def compute_transaction_key(self) -> str:
189
- if self.task.result_storage_key is not None:
174
+ key = None
175
+ if self.task.cache_policy:
176
+ task_run_context = TaskRunContext.get()
177
+ key = self.task.cache_policy.compute_key(
178
+ task_ctx=task_run_context,
179
+ inputs=self.parameters,
180
+ flow_parameters=None,
181
+ )
182
+ elif self.task.result_storage_key is not None:
190
183
  key = _format_user_supplied_storage_key(self.task.result_storage_key)
191
- else:
192
- key = str(self.task_run.id)
193
184
  return key
194
185
 
195
- def _compute_state_details(
196
- self, include_cache_expiration: bool = False
197
- ) -> StateDetails:
198
- task_run_context = TaskRunContext.get()
199
- ## setup cache metadata
200
- cache_key = (
201
- self.task.cache_key_fn(
202
- task_run_context,
203
- self.parameters or {},
204
- )
205
- if self.task.cache_key_fn
206
- else None
207
- )
208
- # Ignore the cached results for a cache key, default = false
209
- # Setting on task level overrules the Prefect setting (env var)
210
- refresh_cache = (
211
- self.task.refresh_cache
212
- if self.task.refresh_cache is not None
213
- else PREFECT_TASKS_REFRESH_CACHE.value()
214
- )
215
-
216
- if include_cache_expiration:
217
- cache_expiration = (
218
- (pendulum.now("utc") + self.task.cache_expiration)
219
- if self.task.cache_expiration
220
- else None
221
- )
222
- else:
223
- cache_expiration = None
224
-
225
- return StateDetails(
226
- cache_key=cache_key,
227
- refresh_cache=refresh_cache,
228
- cache_expiration=cache_expiration,
229
- )
230
-
231
186
  def _resolve_parameters(self):
232
187
  if not self.parameters:
233
188
  return {}
@@ -283,8 +238,7 @@ class TaskRunEngine(Generic[P, R]):
283
238
  )
284
239
  return
285
240
 
286
- state_details = self._compute_state_details()
287
- new_state = Running(state_details=state_details)
241
+ new_state = Running()
288
242
  state = self.set_state(new_state)
289
243
 
290
244
  BACKOFF_MAX = 10
@@ -344,17 +298,9 @@ class TaskRunEngine(Generic[P, R]):
344
298
  if result_factory is None:
345
299
  raise ValueError("Result factory is not set")
346
300
 
347
- # dont put this inside function, else the transaction could get serialized
348
- key = transaction.key
349
-
350
- def key_fn():
351
- return key
352
-
353
- result_factory.storage_key_fn = key_fn
354
301
  terminal_state = run_coro_as_sync(
355
302
  return_value_to_state(
356
- result,
357
- result_factory=result_factory,
303
+ result, result_factory=result_factory, key=transaction.key
358
304
  )
359
305
  )
360
306
  transaction.stage(
@@ -362,20 +308,33 @@ class TaskRunEngine(Generic[P, R]):
362
308
  on_rollback_hooks=self.task.on_rollback_hooks,
363
309
  on_commit_hooks=self.task.on_commit_hooks,
364
310
  )
365
- terminal_state.state_details = self._compute_state_details(
366
- include_cache_expiration=True
367
- )
311
+ if transaction.is_committed():
312
+ terminal_state.name = "Cached"
368
313
  self.set_state(terminal_state)
369
314
  return result
370
315
 
371
316
  def handle_retry(self, exc: Exception) -> bool:
372
- """
373
- If the task has retries left, and the retry condition is met, set the task to retrying.
317
+ """Handle any task run retries.
318
+
319
+ - If the task has retries left, and the retry condition is met, set the task to retrying and return True.
320
+ - If the task has a retry delay, place in AwaitingRetry state with a delayed scheduled time.
374
321
  - If the task has no retries left, or the retry condition is not met, return False.
375
- - If the task has retries left, and the retry condition is met, return True.
376
322
  """
377
323
  if self.retries < self.task.retries and self.can_retry:
378
- self.set_state(Retrying(), force=True)
324
+ if self.task.retry_delay_seconds:
325
+ delay = (
326
+ self.task.retry_delay_seconds[
327
+ min(self.retries, len(self.task.retry_delay_seconds) - 1)
328
+ ] # repeat final delay value if attempts exceed specified delays
329
+ if isinstance(self.task.retry_delay_seconds, Sequence)
330
+ else self.task.retry_delay_seconds
331
+ )
332
+ new_state = AwaitingRetry(
333
+ scheduled_time=pendulum.now("utc").add(seconds=delay)
334
+ )
335
+ else:
336
+ new_state = Retrying()
337
+ self.set_state(new_state, force=True)
379
338
  self.retries = self.retries + 1
380
339
  return True
381
340
  return False
@@ -461,7 +420,7 @@ class TaskRunEngine(Generic[P, R]):
461
420
  yield
462
421
 
463
422
  @contextmanager
464
- def start(
423
+ def initialize_run(
465
424
  self,
466
425
  task_run_id: Optional[UUID] = None,
467
426
  dependencies: Optional[Dict[str, Set[TaskRunInput]]] = None,
@@ -496,12 +455,16 @@ class TaskRunEngine(Generic[P, R]):
496
455
  )
497
456
 
498
457
  yield self
458
+
499
459
  except Exception:
500
460
  # regular exceptions are caught and re-raised to the user
501
461
  raise
502
462
  except (Pause, Abort):
503
463
  # Do not capture internal signals as crashes
504
464
  raise
465
+ except GeneratorExit:
466
+ # Do not capture generator exits as crashes
467
+ raise
505
468
  except BaseException as exc:
506
469
  # BaseExceptions are caught and handled as crashes
507
470
  self.handle_crash(exc)
@@ -528,9 +491,100 @@ class TaskRunEngine(Generic[P, R]):
528
491
  self._client = None
529
492
 
530
493
  def is_running(self) -> bool:
531
- if getattr(self, "task_run", None) is None:
494
+ """Whether or not the engine is currently running a task."""
495
+ if (task_run := getattr(self, "task_run", None)) is None:
532
496
  return False
533
- return getattr(self, "task_run").state.is_running()
497
+ return task_run.state.is_running() or task_run.state.is_scheduled()
498
+
499
+ async def wait_until_ready(self):
500
+ """Waits until the scheduled time (if its the future), then enters Running."""
501
+ if scheduled_time := self.state.state_details.scheduled_time:
502
+ self.logger.info(
503
+ f"Waiting for scheduled time {scheduled_time} for task {self.task.name!r}"
504
+ )
505
+ await anyio.sleep((scheduled_time - pendulum.now("utc")).total_seconds())
506
+ self.set_state(
507
+ Retrying() if self.state.name == "AwaitingRetry" else Running(),
508
+ force=True,
509
+ )
510
+
511
+ # --------------------------
512
+ #
513
+ # The following methods compose the main task run loop
514
+ #
515
+ # --------------------------
516
+
517
+ @contextmanager
518
+ def start(
519
+ self,
520
+ task_run_id: Optional[UUID] = None,
521
+ dependencies: Optional[Dict[str, Set[TaskRunInput]]] = None,
522
+ ) -> Generator[None, None, None]:
523
+ with self.initialize_run(task_run_id=task_run_id, dependencies=dependencies):
524
+ with self.enter_run_context():
525
+ self.logger.debug(
526
+ f"Executing task {self.task.name!r} for task run {self.task_run.name!r}..."
527
+ )
528
+ self.begin_run()
529
+ try:
530
+ yield
531
+ finally:
532
+ self.call_hooks()
533
+
534
+ @contextmanager
535
+ def transaction_context(self) -> Generator[Transaction, None, None]:
536
+ result_factory = getattr(TaskRunContext.get(), "result_factory", None)
537
+
538
+ # refresh cache setting is now repurposes as overwrite transaction record
539
+ overwrite = (
540
+ self.task.refresh_cache
541
+ if self.task.refresh_cache is not None
542
+ else PREFECT_TASKS_REFRESH_CACHE.value()
543
+ )
544
+ with transaction(
545
+ key=self.compute_transaction_key(),
546
+ store=ResultFactoryStore(result_factory=result_factory),
547
+ overwrite=overwrite,
548
+ ) as txn:
549
+ yield txn
550
+
551
+ @contextmanager
552
+ def run_context(self):
553
+ timeout_context = timeout_async if self.task.isasync else timeout
554
+ # reenter the run context to ensure it is up to date for every run
555
+ with self.enter_run_context():
556
+ try:
557
+ with timeout_context(seconds=self.task.timeout_seconds):
558
+ yield self
559
+ except TimeoutError as exc:
560
+ self.handle_timeout(exc)
561
+ except Exception as exc:
562
+ self.handle_exception(exc)
563
+
564
+ def call_task_fn(
565
+ self, transaction: Transaction
566
+ ) -> Union[R, Coroutine[Any, Any, R]]:
567
+ """
568
+ Convenience method to call the task function. Returns a coroutine if the
569
+ task is async.
570
+ """
571
+ parameters = self.parameters or {}
572
+ if self.task.isasync:
573
+
574
+ async def _call_task_fn():
575
+ if transaction.is_committed():
576
+ result = transaction.read()
577
+ else:
578
+ result = await call_with_parameters(self.task.fn, parameters)
579
+ self.handle_success(result, transaction=transaction)
580
+
581
+ return _call_task_fn()
582
+ else:
583
+ if transaction.is_committed():
584
+ result = transaction.read()
585
+ else:
586
+ result = call_with_parameters(self.task.fn, parameters)
587
+ self.handle_success(result, transaction=transaction)
534
588
 
535
589
 
536
590
  def run_task_sync(
@@ -550,56 +604,18 @@ def run_task_sync(
550
604
  wait_for=wait_for,
551
605
  context=context,
552
606
  )
553
- # This is a context manager that keeps track of the run of the task run.
554
- with engine.start(task_run_id=task_run_id, dependencies=dependencies) as run:
555
- with run.enter_run_context():
556
- run.begin_run()
557
- while run.is_running():
558
- # enter run context on each loop iteration to ensure the context
559
- # contains the latest task run metadata
560
- with run.enter_run_context():
561
- try:
562
- # This is where the task is actually run.
563
- with timeout(seconds=run.task.timeout_seconds):
564
- call_args, call_kwargs = parameters_to_args_kwargs(
565
- task.fn, run.parameters or {}
566
- )
567
- run.logger.debug(
568
- f"Executing task {task.name!r} for task run {run.task_run.name!r}..."
569
- )
570
- result_factory = getattr(
571
- TaskRunContext.get(), "result_factory", None
572
- )
573
- with transaction(
574
- key=run.compute_transaction_key(),
575
- store=ResultFactoryStore(result_factory=result_factory),
576
- ) as txn:
577
- if txn.is_committed():
578
- result = txn.read()
579
- else:
580
- result = task.fn(*call_args, **call_kwargs) # type: ignore
581
-
582
- # If the task run is successful, finalize it.
583
- # do this within the transaction lifecycle
584
- # in order to get the proper result serialization
585
- run.handle_success(result, transaction=txn)
586
-
587
- except TimeoutError as exc:
588
- run.handle_timeout(exc)
589
- except Exception as exc:
590
- run.handle_exception(exc)
591
-
592
- if run.state.is_final():
593
- for hook in run.get_hooks(run.state):
594
- hook()
595
-
596
- if return_type == "state":
597
- return run.state
598
- return run.result()
607
+
608
+ with engine.start(task_run_id=task_run_id, dependencies=dependencies):
609
+ while engine.is_running():
610
+ run_coro_as_sync(engine.wait_until_ready())
611
+ with engine.run_context(), engine.transaction_context() as txn:
612
+ engine.call_task_fn(txn)
613
+
614
+ return engine.state if return_type == "state" else engine.result()
599
615
 
600
616
 
601
617
  async def run_task_async(
602
- task: Task[P, Coroutine[Any, Any, R]],
618
+ task: Task[P, R],
603
619
  task_run_id: Optional[UUID] = None,
604
620
  task_run: Optional[TaskRun] = None,
605
621
  parameters: Optional[Dict[str, Any]] = None,
@@ -608,12 +624,6 @@ async def run_task_async(
608
624
  dependencies: Optional[Dict[str, Set[TaskRunInput]]] = None,
609
625
  context: Optional[Dict[str, Any]] = None,
610
626
  ) -> Union[R, State, None]:
611
- """
612
- Runs a task against the API.
613
-
614
- We will most likely want to use this logic as a wrapper and return a coroutine for type inference.
615
- """
616
-
617
627
  engine = TaskRunEngine[P, R](
618
628
  task=task,
619
629
  parameters=parameters,
@@ -621,53 +631,14 @@ async def run_task_async(
621
631
  wait_for=wait_for,
622
632
  context=context,
623
633
  )
624
- # This is a context manager that keeps track of the run of the task run.
625
- with engine.start(task_run_id=task_run_id, dependencies=dependencies) as run:
626
- with run.enter_run_context():
627
- run.begin_run()
628
-
629
- while run.is_running():
630
- # enter run context on each loop iteration to ensure the context
631
- # contains the latest task run metadata
632
- with run.enter_run_context():
633
- try:
634
- # This is where the task is actually run.
635
- with timeout_async(seconds=run.task.timeout_seconds):
636
- call_args, call_kwargs = parameters_to_args_kwargs(
637
- task.fn, run.parameters or {}
638
- )
639
- run.logger.debug(
640
- f"Executing task {task.name!r} for task run {run.task_run.name!r}..."
641
- )
642
- result_factory = getattr(
643
- TaskRunContext.get(), "result_factory", None
644
- )
645
- with transaction(
646
- key=run.compute_transaction_key(),
647
- store=ResultFactoryStore(result_factory=result_factory),
648
- ) as txn:
649
- if txn.is_committed():
650
- result = txn.read()
651
- else:
652
- result = await task.fn(*call_args, **call_kwargs) # type: ignore
653
-
654
- # If the task run is successful, finalize it.
655
- # do this within the transaction lifecycle
656
- # in order to get the proper result serialization
657
- run.handle_success(result, transaction=txn)
658
-
659
- except TimeoutError as exc:
660
- run.handle_timeout(exc)
661
- except Exception as exc:
662
- run.handle_exception(exc)
663
-
664
- if run.state.is_final():
665
- for hook in run.get_hooks(run.state, as_async=True):
666
- await hook()
667
-
668
- if return_type == "state":
669
- return run.state
670
- return run.result()
634
+
635
+ with engine.start(task_run_id=task_run_id, dependencies=dependencies):
636
+ while engine.is_running():
637
+ await engine.wait_until_ready()
638
+ with engine.run_context(), engine.transaction_context() as txn:
639
+ await engine.call_task_fn(txn)
640
+
641
+ return engine.state if return_type == "state" else engine.result()
671
642
 
672
643
 
673
644
  def run_task(
prefect/task_runners.py CHANGED
@@ -288,6 +288,10 @@ class ThreadPoolTaskRunner(TaskRunner[PrefectConcurrentFuture]):
288
288
  super().__exit__(exc_type, exc_value, traceback)
289
289
 
290
290
 
291
+ # Here, we alias ConcurrentTaskRunner to ThreadPoolTaskRunner for backwards compatibility
292
+ ConcurrentTaskRunner = ThreadPoolTaskRunner
293
+
294
+
291
295
  class PrefectTaskRunner(TaskRunner[PrefectDistributedFuture]):
292
296
  def __init__(self):
293
297
  super().__init__()
@@ -321,11 +325,11 @@ class PrefectTaskRunner(TaskRunner[PrefectDistributedFuture]):
321
325
  flow_run_ctx = FlowRunContext.get()
322
326
  if flow_run_ctx:
323
327
  get_run_logger(flow_run_ctx).info(
324
- f"Submitting task {task.name} to for execution by a Prefect task server..."
328
+ f"Submitting task {task.name} to for execution by a Prefect task worker..."
325
329
  )
326
330
  else:
327
331
  self.logger.info(
328
- f"Submitting task {task.name} to for execution by a Prefect task server..."
332
+ f"Submitting task {task.name} to for execution by a Prefect task worker..."
329
333
  )
330
334
 
331
335
  return task.apply_async(