prefect-client 3.2.12__py3-none-any.whl → 3.2.14__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
prefect/workers/base.py CHANGED
@@ -3,31 +3,31 @@ from __future__ import annotations
3
3
  import abc
4
4
  import asyncio
5
5
  import threading
6
+ import warnings
6
7
  from contextlib import AsyncExitStack
7
8
  from functools import partial
8
9
  from typing import (
9
10
  TYPE_CHECKING,
10
11
  Any,
11
12
  Callable,
12
- Dict,
13
13
  Generic,
14
- List,
15
14
  Optional,
16
- Set,
17
15
  Type,
18
- Union,
19
16
  )
20
17
  from uuid import UUID, uuid4
21
18
 
22
19
  import anyio
23
20
  import anyio.abc
24
21
  import httpx
25
- from importlib_metadata import distributions
22
+ from importlib_metadata import (
23
+ distributions, # type: ignore[reportUnknownVariableType] incomplete typing
24
+ )
26
25
  from pydantic import BaseModel, Field, PrivateAttr, field_validator
27
26
  from pydantic.json_schema import GenerateJsonSchema
28
27
  from typing_extensions import Literal, Self, TypeVar
29
28
 
30
29
  import prefect
30
+ from prefect._internal.compatibility.deprecated import PrefectDeprecationWarning
31
31
  from prefect._internal.schemas.validators import return_v_or_none
32
32
  from prefect.client.base import ServerType
33
33
  from prefect.client.orchestration import PrefectClient, get_client
@@ -60,6 +60,7 @@ from prefect.settings import (
60
60
  get_current_settings,
61
61
  )
62
62
  from prefect.states import (
63
+ Cancelled,
63
64
  Crashed,
64
65
  Pending,
65
66
  exception_to_failed_state,
@@ -94,12 +95,12 @@ class BaseJobConfiguration(BaseModel):
94
95
  "will be automatically generated by the worker."
95
96
  ),
96
97
  )
97
- env: Dict[str, Optional[str]] = Field(
98
+ env: dict[str, Optional[str]] = Field(
98
99
  default_factory=dict,
99
100
  title="Environment Variables",
100
101
  description="Environment variables to set when starting a flow run.",
101
102
  )
102
- labels: Dict[str, str] = Field(
103
+ labels: dict[str, str] = Field(
103
104
  default_factory=dict,
104
105
  description=(
105
106
  "Labels applied to infrastructure created by the worker using "
@@ -114,7 +115,7 @@ class BaseJobConfiguration(BaseModel):
114
115
  ),
115
116
  )
116
117
 
117
- _related_objects: Dict[str, Any] = PrivateAttr(default_factory=dict)
118
+ _related_objects: dict[str, Any] = PrivateAttr(default_factory=dict)
118
119
 
119
120
  @property
120
121
  def is_using_a_runner(self) -> bool:
@@ -122,18 +123,18 @@ class BaseJobConfiguration(BaseModel):
122
123
 
123
124
  @field_validator("command")
124
125
  @classmethod
125
- def _coerce_command(cls, v):
126
+ def _coerce_command(cls, v: str | None) -> str | None:
126
127
  return return_v_or_none(v)
127
128
 
128
129
  @field_validator("env", mode="before")
129
130
  @classmethod
130
- def _coerce_env(cls, v):
131
+ def _coerce_env(cls, v: dict[str, Any]) -> dict[str, str | None]:
131
132
  return {k: str(v) if v is not None else None for k, v in v.items()}
132
133
 
133
134
  @staticmethod
134
- def _get_base_config_defaults(variables: dict) -> dict:
135
+ def _get_base_config_defaults(variables: dict[str, Any]) -> dict[str, Any]:
135
136
  """Get default values from base config for all variables that have them."""
136
- defaults = dict()
137
+ defaults: dict[str, Any] = {}
137
138
  for variable_name, attrs in variables.items():
138
139
  # We remove `None` values because we don't want to use them in templating.
139
140
  # The current logic depends on keys not existing to populate the correct value
@@ -149,9 +150,9 @@ class BaseJobConfiguration(BaseModel):
149
150
  @inject_client
150
151
  async def from_template_and_values(
151
152
  cls,
152
- base_job_template: dict,
153
- values: dict,
154
- client: Optional["PrefectClient"] = None,
153
+ base_job_template: dict[str, Any],
154
+ values: dict[str, Any],
155
+ client: "PrefectClient | None" = None,
155
156
  ):
156
157
  """Creates a valid worker configuration object from the provided base
157
158
  configuration and overrides.
@@ -159,7 +160,7 @@ class BaseJobConfiguration(BaseModel):
159
160
  Important: this method expects that the base_job_template was already
160
161
  validated server-side.
161
162
  """
162
- base_config: Dict[str, Any] = base_job_template["job_configuration"]
163
+ base_config: dict[str, Any] = base_job_template["job_configuration"]
163
164
  variables_schema = base_job_template["variables"]
164
165
  variables = cls._get_base_config_defaults(
165
166
  variables_schema.get("properties", {})
@@ -213,8 +214,10 @@ class BaseJobConfiguration(BaseModel):
213
214
  def prepare_for_flow_run(
214
215
  self,
215
216
  flow_run: "FlowRun",
216
- deployment: Optional["DeploymentResponse"] = None,
217
- flow: Optional["Flow"] = None,
217
+ deployment: "DeploymentResponse | None" = None,
218
+ flow: "Flow | None" = None,
219
+ work_pool: "WorkPool | None" = None,
220
+ worker_name: str | None = None,
218
221
  ) -> None:
219
222
  """
220
223
  Prepare the job configuration for a flow run.
@@ -227,6 +230,8 @@ class BaseJobConfiguration(BaseModel):
227
230
  flow_run: The flow run to be executed.
228
231
  deployment: The deployment that the flow run is associated with.
229
232
  flow: The flow that the flow run is associated with.
233
+ work_pool: The work pool that the flow run is running in.
234
+ worker_name: The name of the worker that is submitting the flow run.
230
235
  """
231
236
 
232
237
  self._related_objects = {
@@ -234,26 +239,19 @@ class BaseJobConfiguration(BaseModel):
234
239
  "flow": flow,
235
240
  "flow-run": flow_run,
236
241
  }
237
- if deployment is not None:
238
- deployment_labels = self._base_deployment_labels(deployment)
239
- else:
240
- deployment_labels = {}
241
-
242
- if flow is not None:
243
- flow_labels = self._base_flow_labels(flow)
244
- else:
245
- flow_labels = {}
246
242
 
247
243
  env = {
248
244
  **self._base_environment(),
249
245
  **self._base_flow_run_environment(flow_run),
250
- **(self.env if isinstance(self.env, dict) else {}),
246
+ **(self.env if isinstance(self.env, dict) else {}), # pyright: ignore[reportUnnecessaryIsInstance]
251
247
  }
252
248
  self.env = {key: value for key, value in env.items() if value is not None}
253
249
  self.labels = {
254
250
  **self._base_flow_run_labels(flow_run),
255
- **deployment_labels,
256
- **flow_labels,
251
+ **self._base_work_pool_labels(work_pool),
252
+ **self._base_worker_name_label(worker_name),
253
+ **self._base_flow_labels(flow),
254
+ **self._base_deployment_labels(deployment),
257
255
  **self.labels,
258
256
  }
259
257
  self.name = self.name or flow_run.name
@@ -267,7 +265,7 @@ class BaseJobConfiguration(BaseModel):
267
265
  return "prefect flow-run execute"
268
266
 
269
267
  @staticmethod
270
- def _base_flow_run_labels(flow_run: "FlowRun") -> Dict[str, str]:
268
+ def _base_flow_run_labels(flow_run: "FlowRun") -> dict[str, str]:
271
269
  """
272
270
  Generate a dictionary of labels for a flow run job.
273
271
  """
@@ -278,7 +276,7 @@ class BaseJobConfiguration(BaseModel):
278
276
  }
279
277
 
280
278
  @classmethod
281
- def _base_environment(cls) -> Dict[str, str]:
279
+ def _base_environment(cls) -> dict[str, str]:
282
280
  """
283
281
  Environment variables that should be passed to all created infrastructure.
284
282
 
@@ -287,14 +285,22 @@ class BaseJobConfiguration(BaseModel):
287
285
  return get_current_settings().to_environment_variables(exclude_unset=True)
288
286
 
289
287
  @staticmethod
290
- def _base_flow_run_environment(flow_run: "FlowRun") -> Dict[str, str]:
288
+ def _base_flow_run_environment(flow_run: "FlowRun | None") -> dict[str, str]:
291
289
  """
292
290
  Generate a dictionary of environment variables for a flow run job.
293
291
  """
292
+ if flow_run is None:
293
+ return {}
294
+
294
295
  return {"PREFECT__FLOW_RUN_ID": str(flow_run.id)}
295
296
 
296
297
  @staticmethod
297
- def _base_deployment_labels(deployment: "DeploymentResponse") -> Dict[str, str]:
298
+ def _base_deployment_labels(
299
+ deployment: "DeploymentResponse | None",
300
+ ) -> dict[str, str]:
301
+ if deployment is None:
302
+ return {}
303
+
298
304
  labels = {
299
305
  "prefect.io/deployment-id": str(deployment.id),
300
306
  "prefect.io/deployment-name": deployment.name,
@@ -306,15 +312,37 @@ class BaseJobConfiguration(BaseModel):
306
312
  return labels
307
313
 
308
314
  @staticmethod
309
- def _base_flow_labels(flow: "Flow") -> Dict[str, str]:
315
+ def _base_flow_labels(flow: "Flow | None") -> dict[str, str]:
316
+ if flow is None:
317
+ return {}
318
+
310
319
  return {
311
320
  "prefect.io/flow-id": str(flow.id),
312
321
  "prefect.io/flow-name": flow.name,
313
322
  }
314
323
 
315
- def _related_resources(self) -> List[RelatedResource]:
316
- tags = set()
317
- related = []
324
+ @staticmethod
325
+ def _base_work_pool_labels(work_pool: "WorkPool | None") -> dict[str, str]:
326
+ """Adds the work pool labels to the job manifest."""
327
+ if work_pool is None:
328
+ return {}
329
+
330
+ return {
331
+ "prefect.io/work-pool-name": work_pool.name,
332
+ "prefect.io/work-pool-id": str(work_pool.id),
333
+ }
334
+
335
+ @staticmethod
336
+ def _base_worker_name_label(worker_name: str | None) -> dict[str, str]:
337
+ """Adds the worker name label to the job manifest."""
338
+ if worker_name is None:
339
+ return {}
340
+
341
+ return {"prefect.io/worker-name": worker_name}
342
+
343
+ def _related_resources(self) -> list[RelatedResource]:
344
+ tags: set[str] = set()
345
+ related: list[RelatedResource] = []
318
346
 
319
347
  for kind, obj in self._related_objects.items():
320
348
  if obj is None:
@@ -331,12 +359,12 @@ class BaseVariables(BaseModel):
331
359
  default=None,
332
360
  description="Name given to infrastructure created by a worker.",
333
361
  )
334
- env: Dict[str, Optional[str]] = Field(
362
+ env: dict[str, Optional[str]] = Field(
335
363
  default_factory=dict,
336
364
  title="Environment Variables",
337
365
  description="Environment variables to set when starting a flow run.",
338
366
  )
339
- labels: Dict[str, str] = Field(
367
+ labels: dict[str, str] = Field(
340
368
  default_factory=dict,
341
369
  description="Labels applied to infrastructure created by a worker.",
342
370
  )
@@ -356,7 +384,7 @@ class BaseVariables(BaseModel):
356
384
  ref_template: str = "#/definitions/{model}",
357
385
  schema_generator: Type[GenerateJsonSchema] = GenerateJsonSchema,
358
386
  mode: Literal["validation", "serialization"] = "validation",
359
- ) -> Dict[str, Any]:
387
+ ) -> dict[str, Any]:
360
388
  """TODO: stop overriding this method - use GenerateSchema in ConfigDict instead?"""
361
389
  schema = super().model_json_schema(
362
390
  by_alias, ref_template, schema_generator, mode
@@ -403,14 +431,14 @@ class BaseWorker(abc.ABC, Generic[C, V, R]):
403
431
  def __init__(
404
432
  self,
405
433
  work_pool_name: str,
406
- work_queues: Optional[List[str]] = None,
407
- name: Optional[str] = None,
408
- prefetch_seconds: Optional[float] = None,
434
+ work_queues: list[str] | None = None,
435
+ name: str | None = None,
436
+ prefetch_seconds: float | None = None,
409
437
  create_pool_if_not_found: bool = True,
410
- limit: Optional[int] = None,
411
- heartbeat_interval_seconds: Optional[int] = None,
438
+ limit: int | None = None,
439
+ heartbeat_interval_seconds: int | None = None,
412
440
  *,
413
- base_job_template: Optional[Dict[str, Any]] = None,
441
+ base_job_template: dict[str, Any] | None = None,
414
442
  ):
415
443
  """
416
444
  Base class for all Prefect workers.
@@ -445,7 +473,7 @@ class BaseWorker(abc.ABC, Generic[C, V, R]):
445
473
  self._create_pool_if_not_found = create_pool_if_not_found
446
474
  self._base_job_template = base_job_template
447
475
  self._work_pool_name = work_pool_name
448
- self._work_queues: Set[str] = set(work_queues) if work_queues else set()
476
+ self._work_queues: set[str] = set(work_queues) if work_queues else set()
449
477
 
450
478
  self._prefetch_seconds: float = (
451
479
  prefetch_seconds or PREFECT_WORKER_PREFETCH_SECONDS.value()
@@ -461,11 +489,35 @@ class BaseWorker(abc.ABC, Generic[C, V, R]):
461
489
  self._last_polled_time: DateTime = DateTime.now("utc")
462
490
  self._limit = limit
463
491
  self._limiter: Optional[anyio.CapacityLimiter] = None
464
- self._submitting_flow_run_ids = set()
465
- self._cancelling_flow_run_ids = set()
466
- self._scheduled_task_scopes = set()
492
+ self._submitting_flow_run_ids: set[UUID] = set()
493
+ self._cancelling_flow_run_ids: set[UUID] = set()
494
+ self._scheduled_task_scopes: set[anyio.CancelScope] = set()
467
495
  self._worker_metadata_sent = False
468
496
 
497
+ @property
498
+ def client(self) -> PrefectClient:
499
+ if self._client is None:
500
+ raise RuntimeError(
501
+ "Worker has not been correctly initialized. Please use the worker class as an async context manager."
502
+ )
503
+ return self._client
504
+
505
+ @property
506
+ def work_pool(self) -> WorkPool:
507
+ if self._work_pool is None:
508
+ raise RuntimeError(
509
+ "Worker has not been correctly initialized. Please use the worker class as an async context manager."
510
+ )
511
+ return self._work_pool
512
+
513
+ @property
514
+ def limiter(self) -> anyio.CapacityLimiter:
515
+ if self._limiter is None:
516
+ raise RuntimeError(
517
+ "Worker has not been correctly initialized. Please use the worker class as an async context manager."
518
+ )
519
+ return self._limiter
520
+
469
521
  @classmethod
470
522
  def get_documentation_url(cls) -> str:
471
523
  return cls._documentation_url
@@ -510,7 +562,7 @@ class BaseWorker(abc.ABC, Generic[C, V, R]):
510
562
  return worker_registry.get(type)
511
563
 
512
564
  @staticmethod
513
- def get_all_available_worker_types() -> List[str]:
565
+ def get_all_available_worker_types() -> list[str]:
514
566
  """
515
567
  Returns all worker types available in the local registry.
516
568
  """
@@ -790,7 +842,7 @@ class BaseWorker(abc.ABC, Generic[C, V, R]):
790
842
 
791
843
  should_get_worker_id = self._should_get_worker_id()
792
844
 
793
- params = {
845
+ params: dict[str, Any] = {
794
846
  "work_pool_name": self._work_pool_name,
795
847
  "worker_name": self.name,
796
848
  "heartbeat_interval_seconds": self.heartbeat_interval_seconds,
@@ -852,7 +904,7 @@ class BaseWorker(abc.ABC, Generic[C, V, R]):
852
904
 
853
905
  async def _get_scheduled_flow_runs(
854
906
  self,
855
- ) -> List["WorkerFlowRunResponse"]:
907
+ ) -> list["WorkerFlowRunResponse"]:
856
908
  """
857
909
  Retrieve scheduled flow runs from the work pool's queues.
858
910
  """
@@ -862,7 +914,7 @@ class BaseWorker(abc.ABC, Generic[C, V, R]):
862
914
  )
863
915
  try:
864
916
  scheduled_flow_runs = (
865
- await self._client.get_scheduled_flow_runs_for_work_pool(
917
+ await self.client.get_scheduled_flow_runs_for_work_pool(
866
918
  work_pool_name=self._work_pool_name,
867
919
  scheduled_before=scheduled_before,
868
920
  work_queue_names=list(self._work_queues),
@@ -878,8 +930,8 @@ class BaseWorker(abc.ABC, Generic[C, V, R]):
878
930
  return []
879
931
 
880
932
  async def _submit_scheduled_flow_runs(
881
- self, flow_run_response: List["WorkerFlowRunResponse"]
882
- ) -> List["FlowRun"]:
933
+ self, flow_run_response: list["WorkerFlowRunResponse"]
934
+ ) -> list["FlowRun"]:
883
935
  """
884
936
  Takes a list of WorkerFlowRunResponses and submits the referenced flow runs
885
937
  for execution by the worker.
@@ -897,7 +949,7 @@ class BaseWorker(abc.ABC, Generic[C, V, R]):
897
949
  self._limiter.acquire_on_behalf_of_nowait(flow_run.id)
898
950
  except anyio.WouldBlock:
899
951
  self._logger.info(
900
- f"Flow run limit reached; {self._limiter.borrowed_tokens} flow runs"
952
+ f"Flow run limit reached; {self.limiter.borrowed_tokens} flow runs"
901
953
  " in progress."
902
954
  )
903
955
  break
@@ -921,6 +973,8 @@ class BaseWorker(abc.ABC, Generic[C, V, R]):
921
973
  run_logger.warning(f"Failed to generate worker URL: {ve}")
922
974
 
923
975
  self._submitting_flow_run_ids.add(flow_run.id)
976
+ if TYPE_CHECKING:
977
+ assert self._runs_task_group is not None
924
978
  self._runs_task_group.start_soon(
925
979
  self._submit_run,
926
980
  flow_run,
@@ -939,14 +993,9 @@ class BaseWorker(abc.ABC, Generic[C, V, R]):
939
993
  """
940
994
  run_logger = self.get_flow_run_logger(flow_run)
941
995
 
942
- if flow_run.deployment_id:
943
- assert self._client and self._client._started, (
944
- "Client must be started to check flow run deployment."
945
- )
946
-
947
996
  try:
948
- await self._client.read_deployment(flow_run.deployment_id)
949
- except ObjectNotFound:
997
+ await self.client.read_deployment(getattr(flow_run, "deployment_id"))
998
+ except (ObjectNotFound, AttributeError):
950
999
  self._logger.exception(
951
1000
  f"Deployment {flow_run.deployment_id} no longer exists. "
952
1001
  f"Flow run {flow_run.id} will not be submitted for"
@@ -964,13 +1013,15 @@ class BaseWorker(abc.ABC, Generic[C, V, R]):
964
1013
  ready_to_submit = await self._propose_pending_state(flow_run)
965
1014
  self._logger.debug(f"Ready to submit {flow_run.id}: {ready_to_submit}")
966
1015
  if ready_to_submit:
1016
+ if TYPE_CHECKING:
1017
+ assert self._runs_task_group is not None
967
1018
  readiness_result = await self._runs_task_group.start(
968
1019
  self._submit_run_and_capture_errors, flow_run
969
1020
  )
970
1021
 
971
1022
  if readiness_result and not isinstance(readiness_result, Exception):
972
1023
  try:
973
- await self._client.update_flow_run(
1024
+ await self.client.update_flow_run(
974
1025
  flow_run_id=flow_run.id,
975
1026
  infrastructure_pid=str(readiness_result),
976
1027
  )
@@ -991,8 +1042,10 @@ class BaseWorker(abc.ABC, Generic[C, V, R]):
991
1042
  self._submitting_flow_run_ids.remove(flow_run.id)
992
1043
 
993
1044
  async def _submit_run_and_capture_errors(
994
- self, flow_run: "FlowRun", task_status: Optional[anyio.abc.TaskStatus] = None
995
- ) -> Union[BaseWorkerResult, Exception]:
1045
+ self,
1046
+ flow_run: "FlowRun",
1047
+ task_status: anyio.abc.TaskStatus[int | Exception] | None = None,
1048
+ ) -> BaseWorkerResult | Exception:
996
1049
  run_logger = self.get_flow_run_logger(flow_run)
997
1050
 
998
1051
  try:
@@ -1006,7 +1059,7 @@ class BaseWorker(abc.ABC, Generic[C, V, R]):
1006
1059
  configuration=configuration,
1007
1060
  )
1008
1061
  except Exception as exc:
1009
- if not task_status._future.done():
1062
+ if task_status and not getattr(task_status, "_future").done():
1010
1063
  # This flow run was being submitted and did not start successfully
1011
1064
  run_logger.exception(
1012
1065
  f"Failed to submit flow run '{flow_run.id}' to infrastructure."
@@ -1025,7 +1078,7 @@ class BaseWorker(abc.ABC, Generic[C, V, R]):
1025
1078
  finally:
1026
1079
  self._release_limit_slot(flow_run.id)
1027
1080
 
1028
- if not task_status._future.done():
1081
+ if task_status and not getattr(task_status, "_future").done():
1029
1082
  run_logger.error(
1030
1083
  f"Infrastructure returned without reporting flow run '{flow_run.id}' "
1031
1084
  "as started or raising an error. This behavior is not expected and "
@@ -1033,7 +1086,11 @@ class BaseWorker(abc.ABC, Generic[C, V, R]):
1033
1086
  "flow run will not be marked as failed, but an issue may have occurred."
1034
1087
  )
1035
1088
  # Mark the task as started to prevent agent crash
1036
- task_status.started()
1089
+ task_status.started(
1090
+ RuntimeError(
1091
+ "Infrastructure returned without reporting flow run as started or raising an error."
1092
+ )
1093
+ )
1037
1094
 
1038
1095
  if result.status_code != 0:
1039
1096
  await self._propose_crashed_state(
@@ -1044,11 +1101,12 @@ class BaseWorker(abc.ABC, Generic[C, V, R]):
1044
1101
  ),
1045
1102
  )
1046
1103
 
1047
- self._emit_flow_run_executed_event(result, configuration, submitted_event)
1104
+ if submitted_event:
1105
+ self._emit_flow_run_executed_event(result, configuration, submitted_event)
1048
1106
 
1049
1107
  return result
1050
1108
 
1051
- def _release_limit_slot(self, flow_run_id: str) -> None:
1109
+ def _release_limit_slot(self, flow_run_id: UUID) -> None:
1052
1110
  """
1053
1111
  Frees up a slot taken by the given flow run id.
1054
1112
  """
@@ -1078,14 +1136,12 @@ class BaseWorker(abc.ABC, Generic[C, V, R]):
1078
1136
  flow_run: "FlowRun",
1079
1137
  deployment: Optional["DeploymentResponse"] = None,
1080
1138
  ) -> C:
1081
- deployment = (
1082
- deployment
1083
- if deployment
1084
- else await self._client.read_deployment(flow_run.deployment_id)
1085
- )
1086
- flow = await self._client.read_flow(flow_run.flow_id)
1139
+ if not deployment and flow_run.deployment_id:
1140
+ deployment = await self.client.read_deployment(flow_run.deployment_id)
1087
1141
 
1088
- deployment_vars = deployment.job_variables or {}
1142
+ flow = await self.client.read_flow(flow_run.flow_id)
1143
+
1144
+ deployment_vars = getattr(deployment, "job_variables", {}) or {}
1089
1145
  flow_run_vars = flow_run.job_variables or {}
1090
1146
  job_variables = {**deployment_vars}
1091
1147
 
@@ -1095,22 +1151,37 @@ class BaseWorker(abc.ABC, Generic[C, V, R]):
1095
1151
  job_variables.update(flow_run_vars)
1096
1152
 
1097
1153
  configuration = await self.job_configuration.from_template_and_values(
1098
- base_job_template=self._work_pool.base_job_template,
1154
+ base_job_template=self.work_pool.base_job_template,
1099
1155
  values=job_variables,
1100
- client=self._client,
1101
- )
1102
- configuration.prepare_for_flow_run(
1103
- flow_run=flow_run, deployment=deployment, flow=flow
1156
+ client=self.client,
1104
1157
  )
1158
+ try:
1159
+ configuration.prepare_for_flow_run(
1160
+ flow_run=flow_run,
1161
+ deployment=deployment,
1162
+ flow=flow,
1163
+ work_pool=self.work_pool,
1164
+ worker_name=self.name,
1165
+ )
1166
+ except TypeError:
1167
+ warnings.warn(
1168
+ "This worker is missing the `work_pool` and `worker_name` arguments "
1169
+ "in its JobConfiguration.prepare_for_flow_run method. Please update "
1170
+ "the worker's JobConfiguration class to accept these arguments to "
1171
+ "avoid this warning.",
1172
+ category=PrefectDeprecationWarning,
1173
+ )
1174
+ # Handle older subclasses that don't accept work_pool and worker_name
1175
+ configuration.prepare_for_flow_run(
1176
+ flow_run=flow_run, deployment=deployment, flow=flow
1177
+ )
1105
1178
  return configuration
1106
1179
 
1107
1180
  async def _propose_pending_state(self, flow_run: "FlowRun") -> bool:
1108
1181
  run_logger = self.get_flow_run_logger(flow_run)
1109
1182
  state = flow_run.state
1110
1183
  try:
1111
- state = await propose_state(
1112
- self._client, Pending(), flow_run_id=flow_run.id
1113
- )
1184
+ state = await propose_state(self.client, Pending(), flow_run_id=flow_run.id)
1114
1185
  except Abort as exc:
1115
1186
  run_logger.info(
1116
1187
  (
@@ -1141,7 +1212,7 @@ class BaseWorker(abc.ABC, Generic[C, V, R]):
1141
1212
  run_logger = self.get_flow_run_logger(flow_run)
1142
1213
  try:
1143
1214
  await propose_state(
1144
- self._client,
1215
+ self.client,
1145
1216
  await exception_to_failed_state(message="Submission failed.", exc=exc),
1146
1217
  flow_run_id=flow_run.id,
1147
1218
  )
@@ -1159,7 +1230,7 @@ class BaseWorker(abc.ABC, Generic[C, V, R]):
1159
1230
  run_logger = self.get_flow_run_logger(flow_run)
1160
1231
  try:
1161
1232
  state = await propose_state(
1162
- self._client,
1233
+ self.client,
1163
1234
  Crashed(message=message),
1164
1235
  flow_run_id=flow_run.id,
1165
1236
  )
@@ -1175,14 +1246,19 @@ class BaseWorker(abc.ABC, Generic[C, V, R]):
1175
1246
  )
1176
1247
 
1177
1248
  async def _mark_flow_run_as_cancelled(
1178
- self, flow_run: "FlowRun", state_updates: Optional[dict] = None
1249
+ self, flow_run: "FlowRun", state_updates: dict[str, Any] | None = None
1179
1250
  ) -> None:
1180
1251
  state_updates = state_updates or {}
1181
1252
  state_updates.setdefault("name", "Cancelled")
1182
1253
  state_updates.setdefault("type", StateType.CANCELLED)
1183
- state = flow_run.state.model_copy(update=state_updates)
1184
1254
 
1185
- await self._client.set_flow_run_state(flow_run.id, state, force=True)
1255
+ if flow_run.state:
1256
+ state = flow_run.state.model_copy(update=state_updates)
1257
+ else:
1258
+ # If the flow run unexpectedly has no state, create a new one
1259
+ state = Cancelled(**state_updates)
1260
+
1261
+ await self.client.set_flow_run_state(flow_run.id, state, force=True)
1186
1262
 
1187
1263
  # Do not remove the flow run from the cancelling set immediately because
1188
1264
  # the API caches responses for the `read_flow_runs` and we do not want to
@@ -1191,16 +1267,21 @@ class BaseWorker(abc.ABC, Generic[C, V, R]):
1191
1267
  60 * 10, self._cancelling_flow_run_ids.remove, flow_run.id
1192
1268
  )
1193
1269
 
1194
- async def _set_work_pool_template(self, work_pool, job_template):
1270
+ async def _set_work_pool_template(
1271
+ self, work_pool: "WorkPool", job_template: dict[str, Any]
1272
+ ):
1195
1273
  """Updates the `base_job_template` for the worker's work pool server side."""
1196
- await self._client.update_work_pool(
1274
+
1275
+ await self.client.update_work_pool(
1197
1276
  work_pool_name=work_pool.name,
1198
1277
  work_pool=WorkPoolUpdate(
1199
1278
  base_job_template=job_template,
1200
1279
  ),
1201
1280
  )
1202
1281
 
1203
- async def _schedule_task(self, __in_seconds: int, fn, *args, **kwargs):
1282
+ async def _schedule_task(
1283
+ self, __in_seconds: int, fn: Callable[..., Any], *args: Any, **kwargs: Any
1284
+ ):
1204
1285
  """
1205
1286
  Schedule a background task to start after some time.
1206
1287
 
@@ -1208,8 +1289,12 @@ class BaseWorker(abc.ABC, Generic[C, V, R]):
1208
1289
 
1209
1290
  The function may be async or sync. Async functions will be awaited.
1210
1291
  """
1292
+ if not self._runs_task_group:
1293
+ raise RuntimeError(
1294
+ "Worker has not been correctly initialized. Please use the worker class as an async context manager."
1295
+ )
1211
1296
 
1212
- async def wrapper(task_status):
1297
+ async def wrapper(task_status: anyio.abc.TaskStatus[Any]):
1213
1298
  # If we are shutting down, do not sleep; otherwise sleep until the scheduled
1214
1299
  # time or shutdown
1215
1300
  if self.is_setup:
@@ -1271,12 +1356,12 @@ class BaseWorker(abc.ABC, Generic[C, V, R]):
1271
1356
 
1272
1357
  def _event_related_resources(
1273
1358
  self,
1274
- configuration: Optional[BaseJobConfiguration] = None,
1359
+ configuration: BaseJobConfiguration | None = None,
1275
1360
  include_self: bool = False,
1276
- ) -> List[RelatedResource]:
1277
- related = []
1361
+ ) -> list[RelatedResource]:
1362
+ related: list[RelatedResource] = []
1278
1363
  if configuration:
1279
- related += configuration._related_resources()
1364
+ related += getattr(configuration, "_related_resources")()
1280
1365
 
1281
1366
  if self._work_pool:
1282
1367
  related.append(
@@ -1294,7 +1379,7 @@ class BaseWorker(abc.ABC, Generic[C, V, R]):
1294
1379
 
1295
1380
  def _emit_flow_run_submitted_event(
1296
1381
  self, configuration: BaseJobConfiguration
1297
- ) -> Event:
1382
+ ) -> Event | None:
1298
1383
  return emit_event(
1299
1384
  event="prefect.worker.submitted-flow-run",
1300
1385
  resource=self._event_resource(),
@@ -1305,7 +1390,7 @@ class BaseWorker(abc.ABC, Generic[C, V, R]):
1305
1390
  self,
1306
1391
  result: BaseWorkerResult,
1307
1392
  configuration: BaseJobConfiguration,
1308
- submitted_event: Event,
1393
+ submitted_event: Event | None = None,
1309
1394
  ):
1310
1395
  related = self._event_related_resources(configuration=configuration)
1311
1396
 
@@ -1321,7 +1406,7 @@ class BaseWorker(abc.ABC, Generic[C, V, R]):
1321
1406
  follows=submitted_event,
1322
1407
  )
1323
1408
 
1324
- async def _emit_worker_started_event(self) -> Event:
1409
+ async def _emit_worker_started_event(self) -> Event | None:
1325
1410
  return emit_event(
1326
1411
  "prefect.worker.started",
1327
1412
  resource=self._event_resource(),