prefect-client 2.18.1__py3-none-any.whl → 2.18.3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,10 +1,10 @@
1
1
  import asyncio
2
+ import inspect
2
3
  import logging
3
- from contextlib import asynccontextmanager
4
+ from contextlib import asynccontextmanager, contextmanager
4
5
  from dataclasses import dataclass, field
5
6
  from typing import (
6
7
  Any,
7
- AsyncGenerator,
8
8
  Callable,
9
9
  Coroutine,
10
10
  Dict,
@@ -22,6 +22,11 @@ import pendulum
22
22
  from typing_extensions import ParamSpec
23
23
 
24
24
  from prefect import Task, get_client
25
+ from prefect._internal.concurrency.cancellation import (
26
+ AlarmCancelScope,
27
+ AsyncCancelScope,
28
+ CancelledError,
29
+ )
25
30
  from prefect.client.orchestration import PrefectClient
26
31
  from prefect.client.schemas import TaskRun
27
32
  from prefect.client.schemas.objects import TaskRunResult
@@ -40,7 +45,8 @@ from prefect.states import (
40
45
  exception_to_failed_state,
41
46
  return_value_to_state,
42
47
  )
43
- from prefect.utilities.asyncutils import A, Async, is_async_fn
48
+ from prefect.utilities.asyncutils import A, Async, is_async_fn, run_sync
49
+ from prefect.utilities.callables import parameters_to_args_kwargs
44
50
  from prefect.utilities.engine import (
45
51
  _dynamic_key_for_task_run,
46
52
  _get_hook_name,
@@ -49,32 +55,31 @@ from prefect.utilities.engine import (
49
55
  propose_state,
50
56
  )
51
57
 
58
+ P = ParamSpec("P")
59
+ R = TypeVar("R")
52
60
 
53
- @asynccontextmanager
54
- async def timeout(
55
- delay: Optional[float], *, loop: Optional[asyncio.AbstractEventLoop] = None
56
- ) -> AsyncGenerator[None, None]:
57
- loop = loop or asyncio.get_running_loop()
58
- task = asyncio.current_task(loop=loop)
59
- timer_handle: Optional[asyncio.TimerHandle] = None
60
-
61
- if delay is not None and task is not None:
62
- timer_handle = loop.call_later(delay, task.cancel)
63
61
 
62
+ @asynccontextmanager
63
+ async def timeout(seconds: Optional[float] = None):
64
64
  try:
65
- yield
66
- finally:
67
- if timer_handle is not None:
68
- timer_handle.cancel()
65
+ with AsyncCancelScope(timeout=seconds):
66
+ yield
67
+ except CancelledError:
68
+ raise TimeoutError(f"Task timed out after {seconds} second(s).")
69
69
 
70
70
 
71
- P = ParamSpec("P")
72
- R = TypeVar("R")
71
+ @contextmanager
72
+ def timeout_sync(seconds: Optional[float] = None):
73
+ try:
74
+ with AlarmCancelScope(timeout=seconds):
75
+ yield
76
+ except CancelledError:
77
+ raise TimeoutError(f"Task timed out after {seconds} second(s).")
73
78
 
74
79
 
75
80
  @dataclass
76
81
  class TaskRunEngine(Generic[P, R]):
77
- task: Task[P, Coroutine[Any, Any, R]]
82
+ task: Union[Task[P, R], Task[P, Coroutine[Any, Any, R]]]
78
83
  logger: logging.Logger = field(default_factory=lambda: get_logger("engine"))
79
84
  parameters: Optional[Dict[str, Any]] = None
80
85
  task_run: Optional[TaskRun] = None
@@ -94,16 +99,20 @@ class TaskRunEngine(Generic[P, R]):
94
99
 
95
100
  @property
96
101
  def state(self) -> State:
97
- return self.task_run.state # type: ignore
102
+ if not self.task_run:
103
+ raise ValueError("Task run is not set")
104
+ return self.task_run.state
98
105
 
99
106
  @property
100
107
  def can_retry(self) -> bool:
101
- retry_condition: Optional[ # type: ignore
108
+ retry_condition: Optional[
102
109
  Callable[[Task[P, Coroutine[Any, Any, R]], TaskRun, State], bool]
103
- ] = self.task.retry_condition_fn # type: ignore
110
+ ] = self.task.retry_condition_fn
111
+ if not self.task_run:
112
+ raise ValueError("Task run is not set")
104
113
  return not retry_condition or retry_condition(
105
114
  self.task, self.task_run, self.state
106
- ) # type: ignore
115
+ )
107
116
 
108
117
  async def _run_hooks(self, state: State) -> None:
109
118
  """Run the on_failure and on_completion hooks for a task, making sure to
@@ -112,6 +121,9 @@ class TaskRunEngine(Generic[P, R]):
112
121
  task = self.task
113
122
  task_run = self.task_run
114
123
 
124
+ if not task_run:
125
+ raise ValueError("Task run is not set")
126
+
115
127
  hooks = None
116
128
  if state.is_failed() and task.on_failure:
117
129
  hooks = task.on_failure
@@ -127,9 +139,9 @@ class TaskRunEngine(Generic[P, R]):
127
139
  f" {state.name!r}"
128
140
  )
129
141
  if is_async_fn(hook):
130
- await hook(task=task, task_run=task_run, state=state)
142
+ await hook(task, task_run, state)
131
143
  else:
132
- hook(task=task, task_run=task_run, state=state)
144
+ hook(task, task_run, state)
133
145
  except Exception:
134
146
  self.logger.error(
135
147
  f"An error was encountered while running hook {hook_name!r}",
@@ -148,7 +160,7 @@ class TaskRunEngine(Generic[P, R]):
148
160
  cache_key = (
149
161
  self.task.cache_key_fn(
150
162
  task_run_context,
151
- self.parameters,
163
+ self.parameters or {},
152
164
  )
153
165
  if self.task.cache_key_fn
154
166
  else None
@@ -175,7 +187,7 @@ class TaskRunEngine(Generic[P, R]):
175
187
  cache_expiration=cache_expiration,
176
188
  )
177
189
 
178
- async def begin_run(self) -> State:
190
+ async def begin_run(self):
179
191
  state_details = self._compute_state_details()
180
192
  new_state = Running(state_details=state_details)
181
193
  state = await self.set_state(new_state)
@@ -184,6 +196,8 @@ class TaskRunEngine(Generic[P, R]):
184
196
  state = await self.set_state(new_state)
185
197
 
186
198
  async def set_state(self, state: State, force: bool = False) -> State:
199
+ if not self.task_run:
200
+ raise ValueError("Task run is not set")
187
201
  new_state = await propose_state(
188
202
  self.client, state, task_run_id=self.task_run.id, force=force
189
203
  ) # type: ignore
@@ -197,10 +211,17 @@ class TaskRunEngine(Generic[P, R]):
197
211
  return new_state
198
212
 
199
213
  async def result(self, raise_on_failure: bool = True) -> "Union[R, State, None]":
200
- return await self.state.result(raise_on_failure=raise_on_failure)
214
+ _result = self.state.result(raise_on_failure=raise_on_failure, fetch=True)
215
+ # state.result is a `sync_compatible` function that may or may not return an awaitable
216
+ # depending on whether the parent frame is sync or not
217
+ if inspect.isawaitable(_result):
218
+ _result = await _result
219
+ return _result
201
220
 
202
221
  async def handle_success(self, result: R) -> R:
203
222
  result_factory = getattr(TaskRunContext.get(), "result_factory", None)
223
+ if result_factory is None:
224
+ raise ValueError("Result factory is not set")
204
225
  terminal_state = await return_value_to_state(
205
226
  await resolve_futures_to_states(result),
206
227
  result_factory=result_factory,
@@ -243,20 +264,21 @@ class TaskRunEngine(Generic[P, R]):
243
264
 
244
265
  async def create_task_run(self, client: PrefectClient) -> TaskRun:
245
266
  flow_run_ctx = FlowRunContext.get()
267
+ parameters = self.parameters or {}
246
268
  try:
247
- task_run_name = _resolve_custom_task_run_name(self.task, self.parameters)
269
+ task_run_name = _resolve_custom_task_run_name(self.task, parameters)
248
270
  except TypeError:
249
271
  task_run_name = None
250
272
 
251
273
  # prep input tracking
252
274
  task_inputs = {
253
- k: await collect_task_run_inputs(v) for k, v in self.parameters.items()
275
+ k: await collect_task_run_inputs(v) for k, v in parameters.items()
254
276
  }
255
277
 
256
278
  # anticipate nested runs
257
279
  task_run_ctx = TaskRunContext.get()
258
280
  if task_run_ctx:
259
- task_inputs["wait_for"] = [TaskRunResult(id=task_run_ctx.task_run.id)]
281
+ task_inputs["wait_for"] = [TaskRunResult(id=task_run_ctx.task_run.id)] # type: ignore
260
282
 
261
283
  # TODO: implement wait_for
262
284
  # if wait_for:
@@ -269,16 +291,16 @@ class TaskRunEngine(Generic[P, R]):
269
291
  else:
270
292
  dynamic_key = uuid4().hex
271
293
  task_run = await client.create_task_run(
272
- task=self.task,
294
+ task=self.task, # type: ignore
273
295
  name=task_run_name,
274
296
  flow_run_id=(
275
297
  getattr(flow_run_ctx.flow_run, "id", None)
276
298
  if flow_run_ctx and flow_run_ctx.flow_run
277
299
  else None
278
300
  ),
279
- dynamic_key=dynamic_key,
301
+ dynamic_key=str(dynamic_key),
280
302
  state=Pending(),
281
- task_inputs=task_inputs,
303
+ task_inputs=task_inputs, # type: ignore
282
304
  )
283
305
  return task_run
284
306
 
@@ -287,6 +309,9 @@ class TaskRunEngine(Generic[P, R]):
287
309
  if client is None:
288
310
  client = self.client
289
311
 
312
+ if not self.task_run:
313
+ raise ValueError("Task run is not set")
314
+
290
315
  self.task_run = await client.read_task_run(self.task_run.id)
291
316
 
292
317
  with TaskRunContext(
@@ -294,10 +319,30 @@ class TaskRunEngine(Generic[P, R]):
294
319
  log_prints=self.task.log_prints or False,
295
320
  task_run=self.task_run,
296
321
  parameters=self.parameters,
297
- result_factory=await ResultFactory.from_autonomous_task(self.task),
322
+ result_factory=await ResultFactory.from_autonomous_task(self.task), # type: ignore
298
323
  client=client,
299
324
  ):
300
- self.logger = task_run_logger(task_run=self.task_run, task=self.task)
325
+ self.logger = task_run_logger(task_run=self.task_run, task=self.task) # type: ignore
326
+ yield
327
+
328
+ @contextmanager
329
+ def enter_run_context_sync(self, client: Optional[PrefectClient] = None):
330
+ if client is None:
331
+ client = self.client
332
+ if not self.task_run:
333
+ raise ValueError("Task run is not set")
334
+
335
+ self.task_run = run_sync(client.read_task_run(self.task_run.id))
336
+
337
+ with TaskRunContext(
338
+ task=self.task,
339
+ log_prints=self.task.log_prints or False,
340
+ task_run=self.task_run,
341
+ parameters=self.parameters,
342
+ result_factory=run_sync(ResultFactory.from_autonomous_task(self.task)), # type: ignore
343
+ client=client,
344
+ ):
345
+ self.logger = task_run_logger(task_run=self.task_run, task=self.task) # type: ignore
301
346
  yield
302
347
 
303
348
  @asynccontextmanager
@@ -311,7 +356,6 @@ class TaskRunEngine(Generic[P, R]):
311
356
  try:
312
357
  if not self.task_run:
313
358
  self.task_run = await self.create_task_run(client)
314
-
315
359
  yield self
316
360
  except Exception:
317
361
  # regular exceptions are caught and re-raised to the user
@@ -324,6 +368,38 @@ class TaskRunEngine(Generic[P, R]):
324
368
  self._is_started = False
325
369
  self._client = None
326
370
 
371
+ @contextmanager
372
+ def start_sync(self):
373
+ """
374
+ Enters a client context and creates a task run if needed.
375
+ """
376
+ client = get_client()
377
+ run_sync(client.__aenter__())
378
+ self._client = client
379
+ self._is_started = True
380
+ try:
381
+ if not self.task_run:
382
+ self.task_run = run_sync(self.create_task_run(client))
383
+ yield self
384
+ except Exception:
385
+ # regular exceptions are caught and re-raised to the user
386
+ raise
387
+ except BaseException as exc:
388
+ # BaseExceptions are caught and handled as crashes
389
+ run_sync(self.handle_crash(exc))
390
+ raise
391
+ finally:
392
+ # quickly close client
393
+ run_sync(client.__aexit__(None, None, None))
394
+ self._is_started = False
395
+ self._client = None
396
+
397
+ async def get_client(self):
398
+ if not self._is_started:
399
+ raise RuntimeError("Engine has not started.")
400
+ else:
401
+ return self._client
402
+
327
403
  def is_running(self) -> bool:
328
404
  if getattr(self, "task_run", None) is None:
329
405
  return False
@@ -341,26 +417,28 @@ async def run_task(
341
417
  parameters: Optional[Dict[str, Any]] = None,
342
418
  wait_for: Optional[Iterable[PrefectFuture[A, Async]]] = None,
343
419
  return_type: Literal["state", "result"] = "result",
344
- ) -> "Union[R, State, None]":
420
+ ) -> Union[R, State, None]:
345
421
  """
346
422
  Runs a task against the API.
347
423
 
348
424
  We will most likely want to use this logic as a wrapper and return a coroutine for type inference.
349
425
  """
350
426
  engine = TaskRunEngine[P, R](task=task, parameters=parameters, task_run=task_run)
427
+
428
+ # This is a context manager that keeps track of the run of the task run.
351
429
  async with engine.start() as run:
352
- # This is a context manager that keeps track of the run of the task run.
353
430
  await run.begin_run()
354
431
 
355
432
  while run.is_running():
356
433
  async with run.enter_run_context():
357
434
  try:
358
435
  # This is where the task is actually run.
359
- async with timeout(run.task.timeout_seconds):
360
- if task.isasync:
361
- result = cast(R, await task.fn(**(parameters or {}))) # type: ignore
362
- else:
363
- result = cast(R, task.fn(**(parameters or {}))) # type: ignore
436
+ async with timeout(seconds=run.task.timeout_seconds):
437
+ call_args, call_kwargs = parameters_to_args_kwargs(
438
+ task.fn, run.parameters or {}
439
+ )
440
+ result = cast(R, await task.fn(*call_args, **call_kwargs)) # type: ignore
441
+
364
442
  # If the task run is successful, finalize it.
365
443
  await run.handle_success(result)
366
444
  if return_type == "result":
@@ -372,3 +450,39 @@ async def run_task(
372
450
  if return_type == "state":
373
451
  return run.state
374
452
  return await run.result()
453
+
454
+
455
+ def run_task_sync(
456
+ task: Task[P, R],
457
+ task_run: Optional[TaskRun] = None,
458
+ parameters: Optional[Dict[str, Any]] = None,
459
+ wait_for: Optional[Iterable[PrefectFuture[A, Async]]] = None,
460
+ return_type: Literal["state", "result"] = "result",
461
+ ) -> Union[R, State, None]:
462
+ engine = TaskRunEngine[P, R](task=task, parameters=parameters, task_run=task_run)
463
+
464
+ # This is a context manager that keeps track of the run of the task run.
465
+ with engine.start_sync() as run:
466
+ run_sync(run.begin_run())
467
+
468
+ while run.is_running():
469
+ with run.enter_run_context_sync():
470
+ try:
471
+ # This is where the task is actually run.
472
+ with timeout_sync(seconds=run.task.timeout_seconds):
473
+ call_args, call_kwargs = parameters_to_args_kwargs(
474
+ task.fn, run.parameters or {}
475
+ )
476
+ result = cast(R, task.fn(*call_args, **call_kwargs)) # type: ignore
477
+
478
+ # If the task run is successful, finalize it.
479
+ run_sync(run.handle_success(result))
480
+ if return_type == "result":
481
+ return result
482
+
483
+ except Exception as exc:
484
+ run_sync(run.handle_exception(exc))
485
+
486
+ if return_type == "state":
487
+ return run.state
488
+ return run_sync(run.result())
@@ -64,7 +64,7 @@
64
64
  }
65
65
  },
66
66
  "description": "Execute flow runs as subprocesses on a worker. Works well for local execution when first getting started.",
67
- "display_name": "Local Subprocess",
67
+ "display_name": "Process",
68
68
  "documentation_url": "https://docs.prefect.io/latest/api-ref/prefect/workers/process/",
69
69
  "install_command": "pip install prefect",
70
70
  "is_beta": false,
prefect/settings.py CHANGED
@@ -1221,12 +1221,28 @@ PREFECT_API_SERVICES_FOREMAN_LOOP_SECONDS = Setting(float, default=15)
1221
1221
  """The number of seconds to wait between each iteration of the Foreman loop which checks
1222
1222
  for offline workers and updates work pool status."""
1223
1223
 
1224
+
1225
+ PREFECT_API_SERVICES_FOREMAN_INACTIVITY_HEARTBEAT_MULTIPLE = Setting(int, default=3)
1226
+ "The number of heartbeats that must be missed before a worker is marked as offline."
1227
+
1228
+ PREFECT_API_SERVICES_FOREMAN_FALLBACK_HEARTBEAT_INTERVAL_SECONDS = Setting(
1229
+ int, default=30
1230
+ )
1231
+ """The number of seconds to use for online/offline evaluation if a worker's heartbeat
1232
+ interval is not set."""
1233
+
1224
1234
  PREFECT_API_SERVICES_FOREMAN_DEPLOYMENT_LAST_POLLED_TIMEOUT_SECONDS = Setting(
1225
1235
  int, default=60
1226
1236
  )
1227
1237
  """The number of seconds before a deployment is marked as not ready if it has not been
1228
1238
  polled."""
1229
1239
 
1240
+ PREFECT_API_SERVICES_FOREMAN_WORK_QUEUE_LAST_POLLED_TIMEOUT_SECONDS = Setting(
1241
+ int, default=60
1242
+ )
1243
+ """The number of seconds before a work queue is marked as not ready if it has not been
1244
+ polled."""
1245
+
1230
1246
 
1231
1247
  PREFECT_API_DEFAULT_LIMIT = Setting(
1232
1248
  int,
@@ -1712,6 +1728,11 @@ PREFECT_API_SERVICES_EVENT_PERSISTER_FLUSH_INTERVAL = Setting(float, default=5,
1712
1728
  The maximum number of seconds between flushes of the event persister.
1713
1729
  """
1714
1730
 
1731
+ PREFECT_EVENTS_RETENTION_PERIOD = Setting(timedelta, default=timedelta(days=7))
1732
+ """
1733
+ The amount of time to retain events in the database.
1734
+ """
1735
+
1715
1736
  PREFECT_API_EVENTS_STREAM_OUT_ENABLED = Setting(bool, default=True)
1716
1737
  """
1717
1738
  Whether or not to allow streaming events out of via websockets.
prefect/tasks.py CHANGED
@@ -22,18 +22,22 @@ from typing import (
22
22
  List,
23
23
  NoReturn,
24
24
  Optional,
25
+ Set,
25
26
  TypeVar,
26
27
  Union,
27
28
  cast,
28
29
  overload,
29
30
  )
31
+ from uuid import uuid4
30
32
 
31
33
  from typing_extensions import Literal, ParamSpec
32
34
 
33
35
  from prefect._internal.concurrency.api import create_call, from_async, from_sync
34
36
  from prefect.client.schemas import TaskRun
35
- from prefect.context import FlowRunContext, PrefectObjectRegistry
37
+ from prefect.client.schemas.objects import TaskRunInput
38
+ from prefect.context import FlowRunContext, PrefectObjectRegistry, TagsContext
36
39
  from prefect.futures import PrefectFuture
40
+ from prefect.logging.loggers import get_logger, get_run_logger
37
41
  from prefect.results import ResultSerializer, ResultStorage
38
42
  from prefect.settings import (
39
43
  PREFECT_EXPERIMENTAL_ENABLE_NEW_ENGINE,
@@ -41,7 +45,7 @@ from prefect.settings import (
41
45
  PREFECT_TASK_DEFAULT_RETRIES,
42
46
  PREFECT_TASK_DEFAULT_RETRY_DELAY_SECONDS,
43
47
  )
44
- from prefect.states import State
48
+ from prefect.states import Pending, State
45
49
  from prefect.task_runners import BaseTaskRunner
46
50
  from prefect.utilities.annotations import NotSet
47
51
  from prefect.utilities.asyncutils import Async, Sync
@@ -65,6 +69,8 @@ T = TypeVar("T") # Generic type var for capturing the inner return type of asyn
65
69
  R = TypeVar("R") # The return type of the user's function
66
70
  P = ParamSpec("P") # The parameters of the task
67
71
 
72
+ logger = get_logger("tasks")
73
+
68
74
 
69
75
  def task_input_hash(
70
76
  context: "TaskRunContext", arguments: Dict[str, Any]
@@ -190,14 +196,14 @@ class Task(Generic[P, R]):
190
196
  def __init__(
191
197
  self,
192
198
  fn: Callable[P, R],
193
- name: str = None,
194
- description: str = None,
195
- tags: Iterable[str] = None,
196
- version: str = None,
197
- cache_key_fn: Callable[
198
- ["TaskRunContext", Dict[str, Any]], Optional[str]
199
+ name: Optional[str] = None,
200
+ description: Optional[str] = None,
201
+ tags: Optional[Iterable[str]] = None,
202
+ version: Optional[str] = None,
203
+ cache_key_fn: Optional[
204
+ Callable[["TaskRunContext", Dict[str, Any]], Optional[str]]
199
205
  ] = None,
200
- cache_expiration: datetime.timedelta = None,
206
+ cache_expiration: Optional[datetime.timedelta] = None,
201
207
  task_run_name: Optional[Union[Callable[[], str], str]] = None,
202
208
  retries: Optional[int] = None,
203
209
  retry_delay_seconds: Optional[
@@ -214,7 +220,7 @@ class Task(Generic[P, R]):
214
220
  result_serializer: Optional[ResultSerializer] = None,
215
221
  result_storage_key: Optional[str] = None,
216
222
  cache_result_in_memory: bool = True,
217
- timeout_seconds: Union[int, float] = None,
223
+ timeout_seconds: Union[int, float, None] = None,
218
224
  log_prints: Optional[bool] = False,
219
225
  refresh_cache: Optional[bool] = None,
220
226
  on_completion: Optional[List[Callable[["Task", TaskRun, State], None]]] = None,
@@ -326,6 +332,7 @@ class Task(Generic[P, R]):
326
332
  self.result_serializer = result_serializer
327
333
  self.result_storage_key = result_storage_key
328
334
  self.cache_result_in_memory = cache_result_in_memory
335
+
329
336
  self.timeout_seconds = float(timeout_seconds) if timeout_seconds else None
330
337
  # Warn if this task's `name` conflicts with another task while having a
331
338
  # different function. This is to detect the case where two or more tasks
@@ -530,6 +537,53 @@ class Task(Generic[P, R]):
530
537
  viz_return_value=viz_return_value or self.viz_return_value,
531
538
  )
532
539
 
540
+ async def create_run(
541
+ self,
542
+ flow_run_context: FlowRunContext,
543
+ parameters: Dict[str, Any],
544
+ wait_for: Optional[Iterable[PrefectFuture]],
545
+ extra_task_inputs: Optional[Dict[str, Set[TaskRunInput]]] = None,
546
+ ) -> TaskRun:
547
+ # TODO: Investigate if we can replace create_task_run on the task run engine
548
+ # with this method. Would require updating to work without the flow run context.
549
+ from prefect.utilities.engine import (
550
+ _dynamic_key_for_task_run,
551
+ collect_task_run_inputs,
552
+ )
553
+
554
+ dynamic_key = _dynamic_key_for_task_run(flow_run_context, self)
555
+ task_inputs = {
556
+ k: await collect_task_run_inputs(v) for k, v in parameters.items()
557
+ }
558
+ if wait_for:
559
+ task_inputs["wait_for"] = await collect_task_run_inputs(wait_for)
560
+
561
+ # Join extra task inputs
562
+ extra_task_inputs = extra_task_inputs or {}
563
+ for k, extras in extra_task_inputs.items():
564
+ task_inputs[k] = task_inputs[k].union(extras)
565
+
566
+ flow_run_logger = get_run_logger(flow_run_context)
567
+
568
+ task_run = await flow_run_context.client.create_task_run(
569
+ task=self,
570
+ name=f"{self.name} - {dynamic_key}",
571
+ flow_run_id=flow_run_context.flow_run.id,
572
+ dynamic_key=dynamic_key,
573
+ state=Pending(),
574
+ extra_tags=TagsContext.get().current_tags,
575
+ task_inputs=task_inputs,
576
+ )
577
+
578
+ if flow_run_context.flow_run:
579
+ flow_run_logger.info(
580
+ f"Created task run {task_run.name!r} for task {self.name!r}"
581
+ )
582
+ else:
583
+ logger.info(f"Created task run {task_run.name!r} for task {self.name!r}")
584
+
585
+ return task_run
586
+
533
587
  @overload
534
588
  def __call__(
535
589
  self: "Task[P, NoReturn]",
@@ -585,19 +639,19 @@ class Task(Generic[P, R]):
585
639
 
586
640
  # new engine currently only compatible with async tasks
587
641
  if PREFECT_EXPERIMENTAL_ENABLE_NEW_ENGINE.value():
588
- from prefect.new_task_engine import run_task
589
- from prefect.utilities.asyncutils import run_sync
642
+ from prefect.new_task_engine import run_task, run_task_sync
590
643
 
591
- awaitable = run_task(
644
+ run_kwargs = dict(
592
645
  task=self,
593
646
  parameters=parameters,
594
647
  wait_for=wait_for,
595
648
  return_type=return_type,
596
649
  )
597
650
  if self.isasync:
598
- return awaitable
651
+ # this returns an awaitable coroutine
652
+ return run_task(**run_kwargs)
599
653
  else:
600
- return run_sync(awaitable)
654
+ return run_task_sync(**run_kwargs)
601
655
 
602
656
  if (
603
657
  PREFECT_EXPERIMENTAL_ENABLE_TASK_SCHEDULING.value()
@@ -827,6 +881,7 @@ class Task(Generic[P, R]):
827
881
  # Convert the call args/kwargs to a parameter dict
828
882
  parameters = get_call_parameters(self.fn, args, kwargs)
829
883
  return_type = "state" if return_state else "future"
884
+ flow_run_context = FlowRunContext.get()
830
885
 
831
886
  task_viz_tracker = get_task_viz_tracker()
832
887
  if task_viz_tracker:
@@ -834,10 +889,7 @@ class Task(Generic[P, R]):
834
889
  "`task.submit()` is not currently supported by `flow.visualize()`"
835
890
  )
836
891
 
837
- if (
838
- PREFECT_EXPERIMENTAL_ENABLE_TASK_SCHEDULING.value()
839
- and not FlowRunContext.get()
840
- ):
892
+ if PREFECT_EXPERIMENTAL_ENABLE_TASK_SCHEDULING and not flow_run_context:
841
893
  create_autonomous_task_run_call = create_call(
842
894
  create_autonomous_task_run, task=self, parameters=parameters
843
895
  )
@@ -849,16 +901,74 @@ class Task(Generic[P, R]):
849
901
  return from_sync.wait_for_call_in_loop_thread(
850
902
  create_autonomous_task_run_call
851
903
  )
904
+ if PREFECT_EXPERIMENTAL_ENABLE_NEW_ENGINE and flow_run_context:
905
+ if self.isasync:
906
+ return self._submit_async(
907
+ parameters=parameters,
908
+ flow_run_context=flow_run_context,
909
+ wait_for=wait_for,
910
+ return_state=return_state,
911
+ )
912
+ else:
913
+ raise NotImplementedError(
914
+ "Submitting sync tasks with the new engine has not be implemented yet."
915
+ )
852
916
 
853
- return enter_task_run_engine(
854
- self,
917
+ else:
918
+ return enter_task_run_engine(
919
+ self,
920
+ parameters=parameters,
921
+ wait_for=wait_for,
922
+ return_type=return_type,
923
+ task_runner=None, # Use the flow's task runner
924
+ mapped=False,
925
+ )
926
+
927
+ async def _submit_async(
928
+ self,
929
+ parameters: Dict[str, Any],
930
+ flow_run_context: FlowRunContext,
931
+ wait_for: Optional[Iterable[PrefectFuture]],
932
+ return_state: bool,
933
+ ):
934
+ from prefect.new_task_engine import run_task
935
+
936
+ task_runner = flow_run_context.task_runner
937
+
938
+ task_run = await self.create_run(
939
+ flow_run_context=flow_run_context,
855
940
  parameters=parameters,
856
941
  wait_for=wait_for,
857
- return_type=return_type,
858
- task_runner=None, # Use the flow's task runner
859
- mapped=False,
860
942
  )
861
943
 
944
+ future = PrefectFuture(
945
+ name=task_run.name,
946
+ key=uuid4(),
947
+ task_runner=task_runner,
948
+ asynchronous=(self.isasync and flow_run_context.flow.isasync),
949
+ )
950
+ future.task_run = task_run
951
+ flow_run_context.task_run_futures.append(future)
952
+ await task_runner.submit(
953
+ key=future.key,
954
+ call=partial(
955
+ run_task,
956
+ task=self,
957
+ task_run=task_run,
958
+ parameters=parameters,
959
+ wait_for=wait_for,
960
+ return_type="state",
961
+ ),
962
+ )
963
+ # TODO: I don't like this. Can we move responsibility for creating the future
964
+ # and setting this anyio.Event to the task runner?
965
+ future._submitted.set()
966
+
967
+ if return_state:
968
+ return await future.wait()
969
+ else:
970
+ return future
971
+
862
972
  @overload
863
973
  def map(
864
974
  self: "Task[P, NoReturn]",