prefect-client 3.2.13__py3-none-any.whl → 3.2.15__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (39)
  1. prefect/_build_info.py +3 -3
  2. prefect/_experimental/bundles.py +7 -1
  3. prefect/_internal/concurrency/services.py +13 -3
  4. prefect/cache_policies.py +31 -2
  5. prefect/client/orchestration/_flow_runs/client.py +34 -4
  6. prefect/client/schemas/actions.py +14 -1
  7. prefect/client/schemas/objects.py +18 -0
  8. prefect/deployments/runner.py +1 -9
  9. prefect/docker/docker_image.py +2 -1
  10. prefect/flow_engine.py +11 -5
  11. prefect/flow_runs.py +1 -1
  12. prefect/flows.py +27 -9
  13. prefect/locking/memory.py +16 -8
  14. prefect/logging/__init__.py +1 -1
  15. prefect/logging/configuration.py +6 -4
  16. prefect/logging/formatters.py +3 -3
  17. prefect/logging/handlers.py +37 -26
  18. prefect/results.py +9 -3
  19. prefect/runner/__init__.py +2 -0
  20. prefect/runner/runner.py +1 -1
  21. prefect/runner/server.py +12 -7
  22. prefect/runner/storage.py +37 -37
  23. prefect/runner/submit.py +36 -25
  24. prefect/runner/utils.py +9 -5
  25. prefect/server/api/collections_data/views/aggregate-worker-metadata.json +4 -4
  26. prefect/server/api/flow_runs.py +21 -0
  27. prefect/server/api/task_runs.py +52 -1
  28. prefect/settings/models/tasks.py +5 -0
  29. prefect/task_engine.py +18 -24
  30. prefect/tasks.py +31 -8
  31. prefect/transactions.py +5 -0
  32. prefect/utilities/callables.py +2 -0
  33. prefect/utilities/engine.py +2 -2
  34. prefect/utilities/importtools.py +6 -9
  35. prefect/workers/base.py +9 -4
  36. {prefect_client-3.2.13.dist-info → prefect_client-3.2.15.dist-info}/METADATA +3 -2
  37. {prefect_client-3.2.13.dist-info → prefect_client-3.2.15.dist-info}/RECORD +39 -39
  38. {prefect_client-3.2.13.dist-info → prefect_client-3.2.15.dist-info}/WHEEL +0 -0
  39. {prefect_client-3.2.13.dist-info → prefect_client-3.2.15.dist-info}/licenses/LICENSE +0 -0
prefect/logging/handlers.py CHANGED
@@ -1,5 +1,6 @@
  from __future__ import annotations
 
+ import inspect
  import json
  import logging
  import sys
@@ -44,10 +45,13 @@ else:
  else:
      StreamHandler = logging.StreamHandler
 
+ if TYPE_CHECKING:
+     from prefect.client.schemas.objects import FlowRun, TaskRun
+
 
  class APILogWorker(BatchedQueueService[Dict[str, Any]]):
      @property
-     def _max_batch_size(self) -> int:
+     def max_batch_size(self) -> int:
          return max(
              PREFECT_LOGGING_TO_API_BATCH_SIZE.value()
              - PREFECT_LOGGING_TO_API_MAX_LOG_SIZE.value(),
@@ -55,7 +59,7 @@ class APILogWorker(BatchedQueueService[Dict[str, Any]]):
          )
 
      @property
-     def _min_interval(self) -> float | None:
+     def min_interval(self) -> float | None:
          return PREFECT_LOGGING_TO_API_BATCH_INTERVAL.value()
 
      async def _handle_batch(self, items: list[dict[str, Any]]):
@@ -77,7 +81,7 @@ class APILogWorker(BatchedQueueService[Dict[str, Any]]):
          yield
 
      @classmethod
-     def instance(cls: Type[Self]) -> Self:
+     def instance(cls: Type[Self], *args: Any) -> Self:
          settings = (
              PREFECT_LOGGING_TO_API_BATCH_SIZE.value(),
              PREFECT_API_URL.value(),
@@ -85,7 +89,7 @@ class APILogWorker(BatchedQueueService[Dict[str, Any]]):
          )
 
          # Ensure a unique worker is retrieved per relevant logging settings
-         return super().instance(*settings)
+         return super().instance(*settings, *args)
 
      def _get_size(self, item: Dict[str, Any]) -> int:
          return item.pop("__payload_size__", None) or len(json.dumps(item).encode())
@@ -99,8 +103,7 @@ class APILogHandler(logging.Handler):
      the background.
      """
 
-     @classmethod
-     def flush(cls) -> None:
+     def flush(self) -> None:
          """
          Tell the `APILogWorker` to send any currently enqueued logs and block until
          completion.
@@ -119,22 +122,23 @@
              # Not ideal, but this method is called by the stdlib and cannot return a
              # coroutine so we just schedule the drain in a new thread and continue
              from_sync.call_soon_in_new_thread(create_call(APILogWorker.drain_all))
-             return None
          else:
              # We set a timeout of 5s because we don't want to block forever if the worker
              # is stuck. This can occur when the handler is being shutdown and the
              # `logging._lock` is held but the worker is attempting to emit logs resulting
              # in a deadlock.
-             return APILogWorker.drain_all(timeout=5)
+             APILogWorker.drain_all(timeout=5)
 
      @classmethod
-     async def aflush(cls) -> bool:
+     async def aflush(cls) -> None:
          """
          Tell the `APILogWorker` to send any currently enqueued logs and block until
          completion.
          """
 
-         return await APILogWorker.drain_all()
+         result = APILogWorker.drain_all()
+         if inspect.isawaitable(result):
+             await result
 
      def emit(self, record: logging.LogRecord) -> None:
          """
@@ -202,11 +206,15 @@
                  " flow run contexts unless the flow run id is manually provided."
              ) from None
 
-         if hasattr(context, "flow_run"):
-             flow_run_id = context.flow_run.id
-         elif hasattr(context, "task_run"):
-             flow_run_id = context.task_run.flow_run_id
-             task_run_id = task_run_id or context.task_run.id
+         if flow_run := getattr(context, "flow_run", None):
+             if TYPE_CHECKING:
+                 assert isinstance(flow_run, FlowRun)
+             flow_run_id = flow_run.id
+         elif task_run := getattr(context, "task_run", None):
+             if TYPE_CHECKING:
+                 assert isinstance(task_run, TaskRun)
+             flow_run_id = task_run.flow_run_id
+             task_run_id = task_run_id or task_run.id
          else:
              raise ValueError(
                  "Encountered malformed run context. Does not contain flow or task "
@@ -216,15 +224,14 @@
          # Parsing to a `LogCreate` object here gives us nice parsing error messages
          # from the standard lib `handleError` method if something goes wrong and
          # prevents malformed logs from entering the queue
-         try:
-             is_uuid_like = isinstance(flow_run_id, uuid.UUID) or (
-                 isinstance(flow_run_id, str) and uuid.UUID(flow_run_id)
-             )
-         except ValueError:
-             is_uuid_like = False
+         if isinstance(flow_run_id, str):
+             try:
+                 flow_run_id = uuid.UUID(flow_run_id)
+             except ValueError:
+                 flow_run_id = None
 
          log = LogCreate(
-             flow_run_id=flow_run_id if is_uuid_like else None,
+             flow_run_id=flow_run_id,
              task_run_id=task_run_id,
              worker_id=worker_id,
              name=record.name,
@@ -306,15 +313,19 @@ class PrefectConsoleHandler(StreamHandler):
          styled_console = PREFECT_LOGGING_COLORS.value()
          markup_console = PREFECT_LOGGING_MARKUP.value()
          if styled_console:
-             highlighter = highlighter()
+             highlighter_instance = highlighter()
              theme = Theme(styles, inherit=False)
          else:
-             highlighter = NullHighlighter()
+             highlighter_instance = NullHighlighter()
              theme = Theme(inherit=False)
 
-         self.level = level
+         if isinstance(level, str):
+             self.level: int = logging.getLevelNamesMapping()[level]
+         else:
+             self.level: int = level
+
          self.console: Console = Console(
-             highlighter=highlighter,
+             highlighter=highlighter_instance,
              theme=theme,
              file=self.stream,
              markup=markup_console,
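
Note: the `aflush` change above is the most subtle edit in this file. `APILogWorker.drain_all` can return either `None` or an awaitable depending on the calling context, so the new code awaits the result only when there is something to await. A minimal runnable sketch of the pattern, with a hypothetical `FakeWorker` standing in for `APILogWorker`:

    import asyncio
    import inspect


    class FakeWorker:
        """Hypothetical stand-in for APILogWorker."""

        @classmethod
        def drain_all(cls, timeout: float | None = None):
            # With a timeout the drain happens synchronously; without one the
            # caller receives a coroutine, mirroring drain_all's dual nature.
            if timeout is not None:
                return None
            return asyncio.sleep(0)


    async def aflush() -> None:
        result = FakeWorker.drain_all()
        if inspect.isawaitable(result):
            # Only await when the worker actually handed back a coroutine.
            await result


    asyncio.run(aflush())
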
prefect/results.py CHANGED
@@ -552,7 +552,11 @@
          if self.result_storage_block_id is None and (
              _resolve_path := getattr(self.result_storage, "_resolve_path", None)
          ):
-             return str(_resolve_path(key))
+             path_key = _resolve_path(key)
+             if path_key is not None:
+                 return str(_resolve_path(key))
+             else:
+                 return key
          return key
 
      @sync_compatible
@@ -684,7 +688,9 @@
 
          if self.result_storage_block_id is None:
              if _resolve_path := getattr(self.result_storage, "_resolve_path", None):
-                 key = str(_resolve_path(key))
+                 path_key = _resolve_path(key)
+                 if path_key is not None:
+                     key = str(_resolve_path(key))
 
          return ResultRecord(
              result=obj,
@@ -773,7 +779,7 @@
              )
              else Path(".").resolve()
          )
-         base_key = str(Path(key).relative_to(basepath))
+         base_key = key if basepath is None else str(Path(key).relative_to(basepath))
      else:
          base_key = key
      if (
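
Note: the hunks above guard against `_resolve_path` returning `None`; previously `str(_resolve_path(key))` could quietly produce the literal string "None" as a storage key. A runnable sketch of the fallback logic, using a hypothetical `LocalStyleStorage` in place of a real filesystem block:

    from pathlib import Path
    from typing import Any


    def resolve_key(storage: Any, key: str) -> str:
        # Mirrors the guarded lookup: fall back to the raw key when the
        # storage has no _resolve_path, or when _resolve_path returns None.
        _resolve_path = getattr(storage, "_resolve_path", None)
        if _resolve_path is None:
            return key
        path_key = _resolve_path(key)
        return str(path_key) if path_key is not None else key


    class LocalStyleStorage:
        # Hypothetical storage that resolves keys to absolute paths.
        def _resolve_path(self, key: str) -> Path | None:
            return Path("/results") / key


    print(resolve_key(LocalStyleStorage(), "abc"))  # /results/abc
    print(resolve_key(object(), "abc"))             # abc

prefect/runner/__init__.py CHANGED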
@@ -1,2 +1,4 @@
  from .runner import Runner
  from .submit import submit_to_runner, wait_for_submitted_runs
+
+ __all__ = ["Runner", "submit_to_runner", "wait_for_submitted_runs"]
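
Note: declaring `__all__` makes the re-exports explicit, which is what strict type checkers such as pyright require before treating an imported-but-unused name as part of a module's public API. The observable effect, assuming prefect-client 3.2.15 or later is installed:

    import prefect.runner

    # The subpackage now advertises exactly these names; star imports and
    # type checkers treat them as the intentional public surface.
    print(prefect.runner.__all__)
    # ['Runner', 'submit_to_runner', 'wait_for_submitted_runs']
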
prefect/runner/runner.py CHANGED
@@ -1425,7 +1425,7 @@ class Runner:
              )
              status_code = process.returncode
          except Exception as exc:
-             if not task_status._future.done():
+             if not task_status._future.done():  # type: ignore
                  # This flow run was being submitted and did not start successfully
                  run_logger.exception(
                      f"Failed to start process for flow run '{flow_run.id}'."
prefect/runner/server.py CHANGED
@@ -1,5 +1,7 @@
+ from __future__ import annotations
+
  import uuid
- from typing import TYPE_CHECKING, Any, Callable, Coroutine, Hashable, Optional, Tuple
+ from typing import TYPE_CHECKING, Any, Callable, Coroutine, Hashable, Optional
 
  import uvicorn
  from fastapi import APIRouter, FastAPI, HTTPException, status
@@ -33,7 +35,7 @@ if TYPE_CHECKING:
 
      from pydantic import BaseModel
 
- logger: "logging.Logger" = get_logger("webserver")
+ logger: "logging.Logger" = get_logger("runner.webserver")
 
  RunnableEndpoint = Literal["deployment", "flow", "task"]
 
@@ -45,7 +47,7 @@ class RunnerGenericFlowRunRequest(BaseModel):
 
 
  def perform_health_check(
-     runner: "Runner", delay_threshold: Optional[int] = None
+     runner: "Runner", delay_threshold: int | None = None
  ) -> Callable[..., JSONResponse]:
      if delay_threshold is None:
          delay_threshold = (
@@ -57,6 +59,9 @@ def perform_health_check(
          now = DateTime.now("utc")
          poll_delay = (now - runner.last_polled).total_seconds()
 
+         if TYPE_CHECKING:
+             assert delay_threshold is not None
+
          if poll_delay > delay_threshold:
              return JSONResponse(
                  status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
@@ -120,7 +125,7 @@ async def _build_endpoint_for_deployment(
 
  async def get_deployment_router(
      runner: "Runner",
- ) -> Tuple[APIRouter, dict[Hashable, Any]]:
+ ) -> tuple[APIRouter, dict[Hashable, Any]]:
      router = APIRouter()
      schemas: dict[Hashable, Any] = {}
      async with get_client() as client:
@@ -216,14 +221,14 @@ def _build_generic_endpoint_for_flows(
          # Verify that the flow we're loading is a subflow this runner is
          # managing
          if not _flow_in_schemas(flow, schemas):
-             runner._logger.warning(
+             logger.warning(
                  f"Flow {flow.name} is not directly managed by the runner. Please "
                  "include it in the runner's served flows' import namespace."
              )
          # Verify that the flow we're loading hasn't changed since the webserver
          # was started
          if _flow_schema_changed(flow, schemas):
-             runner._logger.warning(
+             logger.warning(
                  "A change in flow parameters has been detected. Please "
                  "restart the runner."
              )
@@ -291,7 +296,7 @@ async def build_server(runner: "Runner") -> FastAPI:
      return webserver
 
 
- def start_webserver(runner: "Runner", log_level: Optional[str] = None) -> None:
+ def start_webserver(runner: "Runner", log_level: str | None = None) -> None:
      """
      Run a FastAPI server for a runner.
 
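Note: the `if TYPE_CHECKING: assert delay_threshold is not None` trick appears because the parameter is filled in from settings before the inner check function is defined, but the type checker cannot see that across the closure boundary. A self-contained sketch of the same narrowing, with a hypothetical `make_health_check` mirroring `perform_health_check`:

    from __future__ import annotations

    from typing import TYPE_CHECKING


    def make_health_check(delay_threshold: int | None = None):
        if delay_threshold is None:
            delay_threshold = 30  # e.g. filled in from settings

        def check(poll_delay: float) -> bool:
            if TYPE_CHECKING:
                # Seen only by the type checker: narrows int | None to int
                # inside the closure at zero runtime cost.
                assert delay_threshold is not None
            return poll_delay > delay_threshold

        return check


    print(make_health_check()(45.0))  # True
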
prefect/runner/storage.py CHANGED
@@ -6,7 +6,6 @@ from copy import deepcopy
  from pathlib import Path
  from typing import (
      Any,
-     Dict,
      Optional,
      Protocol,
      TypedDict,
@@ -16,7 +15,7 @@ from typing import (
  from urllib.parse import urlparse, urlsplit, urlunparse
  from uuid import uuid4
 
- import fsspec
+ import fsspec  # pyright: ignore[reportMissingTypeStubs]
  from anyio import run_process
  from pydantic import SecretStr
 
@@ -79,7 +78,7 @@ class RunnerStorage(Protocol):
 
  class GitCredentials(TypedDict, total=False):
      username: str
-     access_token: Union[str, Secret[str]]
+     access_token: str | Secret[str]
 
 
  class GitRepository:
@@ -117,12 +116,12 @@ class GitRepository:
      def __init__(
          self,
          url: str,
-         credentials: Union[GitCredentials, Block, Dict[str, Any], None] = None,
-         name: Optional[str] = None,
-         branch: Optional[str] = None,
+         credentials: Union[GitCredentials, Block, dict[str, Any], None] = None,
+         name: str | None = None,
+         branch: str | None = None,
          include_submodules: bool = False,
-         pull_interval: Optional[int] = 60,
-         directories: Optional[str] = None,
+         pull_interval: int | None = 60,
+         directories: list[str] | None = None,
      ):
          if credentials is None:
              credentials = {}
@@ -198,7 +197,7 @@ class GitRepository:
      @property
      def _git_config(self) -> list[str]:
          """Build a git configuration to use when running git commands."""
-         config = {}
+         config: dict[str, str] = {}
 
          # Submodules can be private. The url in .gitmodules
          # will not include the credentials, we need to
@@ -220,7 +219,7 @@ class GitRepository:
              result = await run_process(
                  ["git", "config", "--get", "core.sparseCheckout"], cwd=self.destination
              )
-             return result.strip().lower() == "true"
+             return result.stdout.decode().strip().lower() == "true"
          except Exception:
              return False
 
@@ -243,8 +242,7 @@ class GitRepository:
                  cwd=str(self.destination),
              )
              existing_repo_url = None
-             if result.stdout is not None:
-                 existing_repo_url = _strip_auth_from_url(result.stdout.decode().strip())
+             existing_repo_url = _strip_auth_from_url(result.stdout.decode().strip())
 
              if existing_repo_url != self._url:
                  raise ValueError(
@@ -255,7 +253,7 @@
              # Sparsely checkout the repository if directories are specified and the repo is not in sparse-checkout mode already
              if self._directories and not await self.is_sparsely_checked_out():
                  await run_process(
-                     ["git", "sparse-checkout", "set"] + self._directories,
+                     ["git", "sparse-checkout", "set", *self._directories],
                      cwd=self.destination,
                  )
 
@@ -323,7 +321,7 @@ class GitRepository:
              if self._directories:
                  self._logger.debug("Will add %s", self._directories)
                  await run_process(
-                     ["git", "sparse-checkout", "set"] + self._directories,
+                     ["git", "sparse-checkout", "set", *self._directories],
                      cwd=self.destination,
                  )
 
@@ -343,7 +341,7 @@ class GitRepository:
          )
 
      def to_pull_step(self) -> dict[str, Any]:
-         pull_step = {
+         pull_step: dict[str, Any] = {
              "prefect.deployments.steps.git_clone": {
                  "repository": self._url,
                  "branch": self._branch,
@@ -357,13 +355,14 @@ class GitRepository:
              pull_step["prefect.deployments.steps.git_clone"]["credentials"] = (
                  f"{{{{ {self._credentials.get_block_placeholder()} }}}}"
              )
-         elif isinstance(self._credentials, dict):
-             if isinstance(self._credentials.get("access_token"), Secret):
+         elif isinstance(self._credentials, dict):  # pyright: ignore[reportUnnecessaryIsInstance]
+             if isinstance(
+                 access_token := self._credentials.get("access_token"), Secret
+             ):
                  pull_step["prefect.deployments.steps.git_clone"]["credentials"] = {
                      **self._credentials,
                      "access_token": (
-                         "{{"
-                         f" {self._credentials['access_token'].get_block_placeholder()} }}}}"
+                         f"{{{{ {access_token.get_block_placeholder()} }}}}"
                      ),
                  }
              elif self._credentials.get("access_token") is not None:
@@ -455,10 +454,10 @@ class RemoteStorage:
 
          def replace_blocks_with_values(obj: Any) -> Any:
              if isinstance(obj, Block):
-                 if hasattr(obj, "get"):
-                     return obj.get()
+                 if get := getattr(obj, "get", None):
+                     return get()
                  if hasattr(obj, "value"):
-                     return obj.value
+                     return getattr(obj, "value")
                  else:
                      return obj.model_dump()
              return obj
@@ -467,7 +466,7 @@ class RemoteStorage:
              self._settings, replace_blocks_with_values, return_data=True
          )
 
-         return fsspec.filesystem(scheme, **settings_with_block_values)
+         return fsspec.filesystem(scheme, **settings_with_block_values)  # pyright: ignore[reportUnknownMemberType] missing type stubs
 
      def set_base_path(self, path: Path) -> None:
          self._storage_base_path = path
@@ -513,7 +512,7 @@ class RemoteStorage:
          try:
              await from_async.wait_for_call_in_new_thread(
                  create_call(
-                     self._filesystem.get,
+                     self._filesystem.get,  # pyright: ignore[reportUnknownArgumentType, reportUnknownMemberType] missing type stubs
                      remote_path,
                      str(self.destination),
                      recursive=True,
@@ -580,18 +579,14 @@ class BlockStorageAdapter:
          self._block = block
          self._pull_interval = pull_interval
          self._storage_base_path = Path.cwd()
-         if not isinstance(block, Block):
+         if not isinstance(block, Block):  # pyright: ignore[reportUnnecessaryIsInstance]
              raise TypeError(
                  f"Expected a block object. Received a {type(block).__name__!r} object."
              )
          if not hasattr(block, "get_directory"):
              raise ValueError("Provided block must have a `get_directory` method.")
 
-         self._name = (
-             f"{block.get_block_type_slug()}-{block._block_document_name}"
-             if block._block_document_name
-             else str(uuid4())
-         )
+         self._name = f"{block.get_block_type_slug()}-{getattr(block, '_block_document_name', uuid4())}"
 
      def set_base_path(self, path: Path) -> None:
          self._storage_base_path = path
@@ -610,11 +605,11 @@ class BlockStorageAdapter:
          await self._block.get_directory(local_path=str(self.destination))
 
      def to_pull_step(self) -> dict[str, Any]:
-         # Give blocks the change to implement their own pull step
+         # Give blocks the chance to implement their own pull step
          if hasattr(self._block, "get_pull_step"):
-             return self._block.get_pull_step()
+             return getattr(self._block, "get_pull_step")()
          else:
-             if not self._block._block_document_name:
+             if getattr(self._block, "_block_document_name", None) is None:
                  raise BlockNotSavedError(
                      "Block must be saved with `.save()` before it can be converted to a"
                      " pull step."
@@ -622,7 +617,7 @@ class BlockStorageAdapter:
              return {
                  "prefect.deployments.steps.pull_with_block": {
                      "block_type_slug": self._block.get_block_type_slug(),
-                     "block_document_name": self._block._block_document_name,
+                     "block_document_name": getattr(self._block, "_block_document_name"),
                  }
              }
 
@@ -723,7 +718,9 @@ def create_storage_from_source(
      return LocalStorage(path=source, pull_interval=pull_interval)
 
 
- def _format_token_from_credentials(netloc: str, credentials: dict) -> str:
+ def _format_token_from_credentials(
+     netloc: str, credentials: dict[str, Any] | GitCredentials
+ ) -> str:
      """
      Formats the credentials block for the git provider.
 
@@ -736,7 +733,10 @@ def _format_token_from_credentials(netloc: str, credentials: dict) -> str:
      token = credentials.get("token") if credentials else None
      access_token = credentials.get("access_token") if credentials else None
 
-     user_provided_token = access_token or token or password
+     user_provided_token: str | Secret[str] | None = access_token or token or password
+
+     if isinstance(user_provided_token, Secret):
+         user_provided_token = user_provided_token.get()
 
      if not user_provided_token:
          raise ValueError(
@@ -787,7 +787,7 @@ def _strip_auth_from_url(url: str) -> str:
 
      # Construct a new netloc without the auth info
      netloc = parsed.hostname
-     if parsed.port:
+     if parsed.port and netloc:
          netloc += f":{parsed.port}"
 
      # Build the sanitized URL
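
Note: the `is_sparsely_checked_out` change above fixes a real bug rather than a typing nit. anyio's `run_process` returns a `subprocess.CompletedProcess`, so the old `result.strip()` always raised `AttributeError`, which the bare `except Exception` silently turned into `False`. A sketch of the corrected call:

    import anyio


    async def is_sparse(repo_dir: str) -> bool:
        try:
            # run_process returns CompletedProcess[bytes]; the output lives
            # in .stdout as bytes, not on the result object itself.
            result = await anyio.run_process(
                ["git", "config", "--get", "core.sparseCheckout"], cwd=repo_dir
            )
            return result.stdout.decode().strip().lower() == "true"
        except Exception:
            # Also reached when the config key is unset (git exits nonzero).
            return False


    print(anyio.run(is_sparse, "."))
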
prefect/runner/submit.py CHANGED
@@ -3,15 +3,19 @@ from __future__ import annotations
  import asyncio
  import inspect
  import uuid
- from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union, overload
+ from typing import TYPE_CHECKING, Any, Union, overload
 
  import anyio
  import httpx
  from typing_extensions import Literal, TypeAlias
 
  from prefect.client.orchestration import get_client
- from prefect.client.schemas.filters import FlowRunFilter, TaskRunFilter
- from prefect.client.schemas.objects import FlowRun
+ from prefect.client.schemas.filters import (
+     FlowRunFilter,
+     FlowRunFilterParentFlowRunId,
+     TaskRunFilter,
+ )
+ from prefect.client.schemas.objects import Constant, FlowRun, Parameter, TaskRunResult
  from prefect.context import FlowRunContext
  from prefect.flows import Flow
  from prefect.logging import get_logger
@@ -60,18 +64,20 @@ async def _submit_flow_to_runner(
 
      parent_flow_run_context = FlowRunContext.get()
 
-     task_inputs = {
-         k: await collect_task_run_inputs(v) for k, v in parameters.items()
+     task_inputs: dict[str, list[TaskRunResult | Parameter | Constant]] = {
+         k: list(await collect_task_run_inputs(v)) for k, v in parameters.items()
      }
      parameters = await resolve_inputs(parameters)
      dummy_task = Task(name=flow.name, fn=flow.fn, version=flow.version)
      parent_task_run = await client.create_task_run(
          task=dummy_task,
          flow_run_id=(
-             parent_flow_run_context.flow_run.id if parent_flow_run_context else None
+             parent_flow_run_context.flow_run.id
+             if parent_flow_run_context and parent_flow_run_context.flow_run
+             else None
          ),
          dynamic_key=(
-             dynamic_key_for_task_run(parent_flow_run_context, dummy_task)
+             str(dynamic_key_for_task_run(parent_flow_run_context, dummy_task))
              if parent_flow_run_context
              else str(uuid.uuid4())
          ),
@@ -79,14 +85,15 @@ async def _submit_flow_to_runner(
          state=Pending(),
      )
 
-     response = await client._client.post(
+     httpx_client = getattr(client, "_client")
+     response = await httpx_client.post(
          (
              f"http://{PREFECT_RUNNER_SERVER_HOST.value()}"
              f":{PREFECT_RUNNER_SERVER_PORT.value()}"
              "/flow/run"
          ),
          json={
-             "entrypoint": flow._entrypoint,
+             "entrypoint": getattr(flow, "_entrypoint"),
              "parameters": flow.serialize_parameters(parameters),
              "parent_task_run_id": str(parent_task_run.id),
          },
@@ -98,15 +105,15 @@ async def _submit_flow_to_runner(
 
  @overload
  def submit_to_runner(
-     prefect_callable: Union[Flow[Any, Any], Task[Any, Any]],
-     parameters: Dict[str, Any],
+     prefect_callable: Flow[Any, Any] | Task[Any, Any],
+     parameters: dict[str, Any],
      retry_failed_submissions: bool = True,
  ) -> FlowRun: ...
 
 
  @overload
  def submit_to_runner(
-     prefect_callable: Union[Flow[Any, Any], Task[Any, Any]],
+     prefect_callable: Flow[Any, Any] | Task[Any, Any],
      parameters: list[dict[str, Any]],
      retry_failed_submissions: bool = True,
  ) -> list[FlowRun]: ...
@@ -114,10 +121,10 @@ def submit_to_runner(
 
  @sync_compatible
  async def submit_to_runner(
-     prefect_callable: FlowOrTask,
-     parameters: Optional[Union[dict[str, Any], list[dict[str, Any]]]] = None,
+     prefect_callable: Flow[Any, Any],
+     parameters: dict[str, Any] | list[dict[str, Any]] | None = None,
      retry_failed_submissions: bool = True,
- ) -> Union[FlowRun, list[FlowRun]]:
+ ) -> FlowRun | list[FlowRun]:
      """
      Submit a callable in the background via the runner webserver one or more times.
 
@@ -127,22 +134,22 @@ async def submit_to_runner(
          each dictionary represents a discrete invocation of the callable
          retry_failed_submissions: Whether to retry failed submissions to the runner webserver.
      """
-     if not isinstance(prefect_callable, (Flow, Task)):
+     if not isinstance(prefect_callable, Flow):  # pyright: ignore[reportUnnecessaryIsInstance]
          raise TypeError(
              "The `submit_to_runner` utility only supports submitting flows and tasks."
          )
 
      parameters = parameters or {}
-     if isinstance(parameters, List):
+     if isinstance(parameters, list):
          return_single = False
-     elif isinstance(parameters, dict):
+     elif isinstance(parameters, dict):  # pyright: ignore[reportUnnecessaryIsInstance]
          parameters = [parameters]
          return_single = True
      else:
          raise TypeError("Parameters must be a dictionary or a list of dictionaries.")
 
-     submitted_runs = []
-     unsubmitted_parameters = []
+     submitted_runs: list[FlowRun] = []
+     unsubmitted_parameters: list[dict[str, Any]] = []
 
      for p in parameters:
          try:
@@ -181,9 +188,9 @@ async def submit_to_runner(
 
  @sync_compatible
  async def wait_for_submitted_runs(
-     flow_run_filter: Optional[FlowRunFilter] = None,
-     task_run_filter: Optional[TaskRunFilter] = None,
-     timeout: Optional[float] = None,
+     flow_run_filter: FlowRunFilter | None = None,
+     task_run_filter: TaskRunFilter | None = None,
+     timeout: float | None = None,
      poll_interval: float = 3.0,
  ):
      """
@@ -197,7 +204,9 @@ async def wait_for_submitted_runs(
          poll_interval: How long to wait between polling each run's state (seconds).
      """
 
-     parent_flow_run_id = ctx.flow_run.id if (ctx := FlowRunContext.get()) else None
+     parent_flow_run_id = (
+         ctx.flow_run.id if ((ctx := FlowRunContext.get()) and ctx.flow_run) else None
+     )
 
      if task_run_filter:
          raise NotImplementedError("Waiting for task runs is not yet supported.")
@@ -223,7 +232,9 @@ async def wait_for_submitted_runs(
      if parent_flow_run_id is not None:
          subflow_runs = await client.read_flow_runs(
              flow_run_filter=FlowRunFilter(
-                 parent_flow_run_id=dict(any_=[parent_flow_run_id])
+                 parent_flow_run_id=FlowRunFilterParentFlowRunId(
+                     any_=[parent_flow_run_id]
+                 )
              )
          )
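
Note: the final hunk swaps a bare dict for the `FlowRunFilterParentFlowRunId` field model. At runtime pydantic coerces both forms into the same filter; the explicit model simply survives strict type checking. A sketch, assuming prefect-client is installed:

    import uuid

    from prefect.client.schemas.filters import (
        FlowRunFilter,
        FlowRunFilterParentFlowRunId,
    )

    parent_id = uuid.uuid4()

    # Old style: a bare dict, coerced by pydantic at runtime but opaque
    # to the type checker.
    implicit = FlowRunFilter(parent_flow_run_id=dict(any_=[parent_id]))

    # New style: the explicit field model; identical once validated.
    explicit = FlowRunFilter(
        parent_flow_run_id=FlowRunFilterParentFlowRunId(any_=[parent_id])
    )

    assert implicit == explicit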