prefect-client 3.1.6__py3-none-any.whl → 3.1.7__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (49)
  1. prefect/_experimental/__init__.py +0 -0
  2. prefect/_experimental/lineage.py +181 -0
  3. prefect/_internal/compatibility/async_dispatch.py +38 -9
  4. prefect/_internal/pydantic/v2_validated_func.py +15 -10
  5. prefect/_internal/retries.py +15 -6
  6. prefect/_internal/schemas/bases.py +2 -1
  7. prefect/_internal/schemas/validators.py +5 -4
  8. prefect/_version.py +3 -3
  9. prefect/blocks/core.py +144 -17
  10. prefect/blocks/system.py +2 -1
  11. prefect/client/orchestration.py +88 -0
  12. prefect/client/schemas/actions.py +5 -5
  13. prefect/client/schemas/filters.py +1 -1
  14. prefect/client/schemas/objects.py +5 -5
  15. prefect/client/schemas/responses.py +1 -2
  16. prefect/client/schemas/schedules.py +1 -1
  17. prefect/client/subscriptions.py +2 -1
  18. prefect/client/utilities.py +15 -1
  19. prefect/context.py +1 -1
  20. prefect/deployments/flow_runs.py +3 -3
  21. prefect/deployments/runner.py +14 -14
  22. prefect/deployments/steps/core.py +3 -1
  23. prefect/deployments/steps/pull.py +60 -12
  24. prefect/events/clients.py +55 -4
  25. prefect/events/filters.py +1 -1
  26. prefect/events/related.py +2 -1
  27. prefect/events/schemas/events.py +1 -1
  28. prefect/events/utilities.py +2 -0
  29. prefect/events/worker.py +8 -0
  30. prefect/flow_engine.py +41 -81
  31. prefect/flow_runs.py +4 -2
  32. prefect/flows.py +4 -6
  33. prefect/results.py +43 -22
  34. prefect/runner/storage.py +3 -3
  35. prefect/serializers.py +28 -24
  36. prefect/settings/models/experiments.py +5 -0
  37. prefect/task_engine.py +34 -26
  38. prefect/task_worker.py +43 -25
  39. prefect/tasks.py +118 -125
  40. prefect/telemetry/instrumentation.py +1 -1
  41. prefect/telemetry/processors.py +10 -7
  42. prefect/telemetry/run_telemetry.py +157 -33
  43. prefect/types/__init__.py +4 -1
  44. prefect/variables.py +127 -19
  45. {prefect_client-3.1.6.dist-info → prefect_client-3.1.7.dist-info}/METADATA +2 -1
  46. {prefect_client-3.1.6.dist-info → prefect_client-3.1.7.dist-info}/RECORD +49 -47
  47. {prefect_client-3.1.6.dist-info → prefect_client-3.1.7.dist-info}/LICENSE +0 -0
  48. {prefect_client-3.1.6.dist-info → prefect_client-3.1.7.dist-info}/WHEEL +0 -0
  49. {prefect_client-3.1.6.dist-info → prefect_client-3.1.7.dist-info}/top_level.txt +0 -0
prefect/results.py CHANGED
@@ -35,10 +35,13 @@ from pydantic import (
     model_validator,
 )
 from pydantic_core import PydanticUndefinedType
-from pydantic_extra_types.pendulum_dt import DateTime
 from typing_extensions import ParamSpec, Self

 import prefect
+from prefect._experimental.lineage import (
+    emit_result_read_event,
+    emit_result_write_event,
+)
 from prefect._internal.compatibility import deprecated
 from prefect._internal.compatibility.deprecated import deprecated_field
 from prefect.blocks.core import Block
@@ -57,6 +60,7 @@ from prefect.locking.protocol import LockManager
 from prefect.logging import get_logger
 from prefect.serializers import PickleSerializer, Serializer
 from prefect.settings.context import get_current_settings
+from prefect.types import DateTime
 from prefect.utilities.annotations import NotSet
 from prefect.utilities.asyncutils import sync_compatible
 from prefect.utilities.pydantic import get_dispatch_key, lookup_type, register_base_type
@@ -129,7 +133,7 @@ async def resolve_result_storage(
     elif isinstance(result_storage, Path):
         storage_block = LocalFileSystem(basepath=str(result_storage))
     elif isinstance(result_storage, str):
-        storage_block = await Block.load(result_storage, client=client)
+        storage_block = await Block.aload(result_storage, client=client)
         storage_block_id = storage_block._block_document_id
         assert storage_block_id is not None, "Loaded storage blocks must have ids"
     elif isinstance(result_storage, UUID):
@@ -168,7 +172,7 @@ async def get_or_create_default_task_scheduling_storage() -> ResultStorage:
     default_block = settings.tasks.scheduling.default_storage_block

     if default_block is not None:
-        return await Block.load(default_block)
+        return await Block.aload(default_block)

     # otherwise, use the local file system
     basepath = settings.results.local_storage_path
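The two hunks above swap `Block.load` for `Block.aload` inside `async` functions: 3.1.7 adds a dedicated async loader so async call sites no longer round-trip through the sync-compatible wrapper. A minimal sketch of the split (the block name "my-results" is hypothetical):

    from prefect.filesystems import LocalFileSystem

    # Sync call sites can keep using the sync-compatible classmethod:
    block = LocalFileSystem.load("my-results")

    # Async call sites should await the async variant instead:
    async def get_block():
        return await LocalFileSystem.aload("my-results")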
@@ -232,6 +236,10 @@ def _format_user_supplied_storage_key(key: str) -> str:
 T = TypeVar("T")


+def default_cache() -> LRUCache[str, "ResultRecord[Any]"]:
+    return LRUCache(maxsize=1000)
+
+
 def result_storage_discriminator(x: Any) -> str:
     if isinstance(x, dict):
         if "block_type_slug" in x:
@@ -284,7 +292,7 @@ class ResultStore(BaseModel):
     cache_result_in_memory: bool = Field(default=True)
     serializer: Serializer = Field(default_factory=get_default_result_serializer)
     storage_key_fn: Callable[[], str] = Field(default=DEFAULT_STORAGE_KEY_FN)
-    cache: LRUCache = Field(default_factory=lambda: LRUCache(maxsize=1000))
+    cache: LRUCache[str, "ResultRecord[Any]"] = Field(default_factory=default_cache)

     # Deprecated fields
     persist_result: Optional[bool] = Field(default=None)
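The untyped `LRUCache` with a lambda default becomes a parameterized `LRUCache[str, "ResultRecord[Any]"]` built by the named `default_cache` factory, giving type checkers concrete key/value types and making the default an importable callable. A rough illustration of the cache semantics, assuming the usual `cachetools` behavior:

    from cachetools import LRUCache

    cache: "LRUCache[str, dict]" = LRUCache(maxsize=2)
    cache["a"], cache["b"] = {"v": 1}, {"v": 2}
    cache["c"] = {"v": 3}  # evicts "a", the least recently used entry
    assert "a" not in cache and len(cache) == 2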
@@ -319,7 +327,7 @@ class ResultStore(BaseModel):
         return self.model_copy(update=update)

     @sync_compatible
-    async def update_for_task(self: Self, task: "Task") -> Self:
+    async def update_for_task(self: Self, task: "Task[P, R]") -> Self:
         """
         Create a new result store for a task.

@@ -446,8 +454,15 @@ class ResultStore(BaseModel):
         """
         return await self._exists(key=key, _sync=False)

+    def _resolved_key_path(self, key: str) -> str:
+        if self.result_storage_block_id is None and hasattr(
+            self.result_storage, "_resolve_path"
+        ):
+            return str(self.result_storage._resolve_path(key))
+        return key
+
     @sync_compatible
-    async def _read(self, key: str, holder: str) -> "ResultRecord":
+    async def _read(self, key: str, holder: str) -> "ResultRecord[Any]":
         """
         Read a result record from storage.

@@ -465,8 +480,12 @@ class ResultStore(BaseModel):
         if self.lock_manager is not None and not self.is_lock_holder(key, holder):
             await self.await_for_lock(key)

-        if key in self.cache:
-            return self.cache[key]
+        resolved_key_path = self._resolved_key_path(key)
+
+        if resolved_key_path in self.cache:
+            cached_result = self.cache[resolved_key_path]
+            await emit_result_read_event(self, resolved_key_path, cached=True)
+            return cached_result

         if self.result_storage is None:
             self.result_storage = await get_default_result_storage()
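`_read` now normalizes the cache key up front via the new `_resolved_key_path` helper, so the same record requested by a relative key and by its resolved absolute path maps to one cache entry, and the lineage read events report the same path on both the cached and uncached branches. Roughly, for local storage without a block ID (a sketch; `_resolve_path` is a private helper of the filesystem blocks):

    from prefect.filesystems import LocalFileSystem

    storage = LocalFileSystem(basepath="/tmp/prefect-results")
    # "abc123" and "/tmp/prefect-results/abc123" now share one cache entry:
    print(storage._resolve_path("abc123"))  # /tmp/prefect-results/abc123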
@@ -478,31 +497,28 @@ class ResultStore(BaseModel):
                 metadata.storage_key is not None
             ), "Did not find storage key in metadata"
             result_content = await self.result_storage.read_path(metadata.storage_key)
-            result_record = ResultRecord.deserialize_from_result_and_metadata(
+            result_record: ResultRecord[
+                Any
+            ] = ResultRecord.deserialize_from_result_and_metadata(
                 result=result_content, metadata=metadata_content
             )
+            await emit_result_read_event(self, resolved_key_path)
         else:
             content = await self.result_storage.read_path(key)
-            result_record = ResultRecord.deserialize(
+            result_record: ResultRecord[Any] = ResultRecord.deserialize(
                 content, backup_serializer=self.serializer
             )
+            await emit_result_read_event(self, resolved_key_path)

         if self.cache_result_in_memory:
-            if self.result_storage_block_id is None and hasattr(
-                self.result_storage, "_resolve_path"
-            ):
-                cache_key = str(self.result_storage._resolve_path(key))
-            else:
-                cache_key = key
-
-            self.cache[cache_key] = result_record
+            self.cache[resolved_key_path] = result_record
         return result_record

     def read(
         self,
         key: str,
         holder: Optional[str] = None,
-    ) -> "ResultRecord":
+    ) -> "ResultRecord[Any]":
         """
         Read a result record from storage.

@@ -520,7 +536,7 @@ class ResultStore(BaseModel):
         self,
         key: str,
         holder: Optional[str] = None,
-    ) -> "ResultRecord":
+    ) -> "ResultRecord[Any]":
         """
         Read a result record from storage.

@@ -663,12 +679,13 @@ class ResultStore(BaseModel):
                 base_key,
                 content=result_record.serialize_metadata(),
             )
+            await emit_result_write_event(self, result_record.metadata.storage_key)
         # Otherwise, write the result metadata and result together
         else:
             await self.result_storage.write_path(
                 result_record.metadata.storage_key, content=result_record.serialize()
             )
-
+            await emit_result_write_event(self, result_record.metadata.storage_key)
         if self.cache_result_in_memory:
             self.cache[key] = result_record

@@ -898,7 +915,11 @@ class ResultStore(BaseModel):
         )

     @sync_compatible
-    async def read_parameters(self, identifier: UUID) -> Dict[str, Any]:
+    async def read_parameters(self, identifier: UUID) -> dict[str, Any]:
+        if self.result_storage is None:
+            raise ValueError(
+                "Result store is not configured - must have a result storage block to read parameters"
+            )
         record = ResultRecord.deserialize(
             await self.result_storage.read_path(f"parameters/{identifier}")
         )
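The new guard turns an opaque `AttributeError` (calling `read_path` on `None`) into an explicit `ValueError`. A minimal sketch of the failure mode, assuming a bare `ResultStore()` leaves `result_storage` unset:

    from uuid import uuid4
    from prefect.results import ResultStore

    store = ResultStore()  # no result storage configured
    try:
        store.read_parameters(uuid4())  # sync-compatible call
    except ValueError as exc:
        print(exc)  # "Result store is not configured - ..."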
prefect/runner/storage.py CHANGED
@@ -53,14 +53,14 @@ class RunnerStorage(Protocol):
         """
         ...

-    def to_pull_step(self) -> dict:
+    def to_pull_step(self) -> dict[str, Any]:
         """
         Returns a dictionary representation of the storage object that can be
         used as a deployment pull step.
         """
         ...

-    def __eq__(self, __value) -> bool:
+    def __eq__(self, __value: Any) -> bool:
         """
         Equality check for runner storage objects.
         """
@@ -69,7 +69,7 @@

 class GitCredentials(TypedDict, total=False):
     username: str
-    access_token: Union[str, Secret]
+    access_token: Union[str, Secret[str]]


 class GitRepository:
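`GitCredentials` now parameterizes the `Secret` block, so tools can infer the secret's inner type as `str`. Usage is unchanged; a brief sketch (the username and token values are hypothetical):

    from prefect.blocks.system import Secret
    from prefect.runner.storage import GitCredentials

    creds: GitCredentials = {
        "username": "octocat",
        "access_token": Secret(value="ghp_example"),  # typed as Secret[str]
    }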
prefect/serializers.py CHANGED
@@ -13,7 +13,7 @@ bytes to an object respectively.

 import abc
 import base64
-from typing import Any, Dict, Generic, Optional, Type
+from typing import Any, Generic, Optional, Type, Union

 from pydantic import (
     BaseModel,
@@ -23,7 +23,7 @@ from pydantic import (
     ValidationError,
     field_validator,
 )
-from typing_extensions import Literal, Self, TypeVar
+from typing_extensions import Self, TypeVar

 from prefect._internal.schemas.validators import (
     cast_type_names_to_serializers,
@@ -54,7 +54,7 @@ def prefect_json_object_encoder(obj: Any) -> Any:
     }


-def prefect_json_object_decoder(result: dict):
+def prefect_json_object_decoder(result: dict[str, Any]):
     """
     `JSONDecoder.object_hook` for decoding objects from JSON when previously encoded
     with `prefect_json_object_encoder`
@@ -80,12 +80,16 @@ class Serializer(BaseModel, Generic[D], abc.ABC):
         data.setdefault("type", type_string)
         super().__init__(**data)

-    def __new__(cls: Type[Self], **kwargs) -> Self:
+    def __new__(cls: Type[Self], **kwargs: Any) -> Self:
         if "type" in kwargs:
             try:
                 subcls = lookup_type(cls, dispatch_key=kwargs["type"])
             except KeyError as exc:
-                raise ValidationError(errors=[exc], model=cls)
+                raise ValidationError.from_exception_data(
+                    title=cls.__name__,
+                    line_errors=[{"type": str(exc), "input": kwargs["type"]}],
+                    input_type="python",
+                )

             return super().__new__(subcls)
         else:
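For context, `Serializer.__new__` dispatches on the `type` key, which is why the unknown-key error path matters; the fix uses pydantic v2's `ValidationError.from_exception_data` rather than the v1-style `ValidationError(errors=..., model=...)` constructor, which v2 removed. The dispatch itself, in a short sketch:

    from prefect.serializers import JSONSerializer, Serializer

    s = Serializer(type="json")  # dispatches to the registered subclass
    assert isinstance(s, JSONSerializer)
    assert s.loads(s.dumps({"a": 1})) == {"a": 1}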
@@ -104,7 +108,7 @@ class Serializer(BaseModel, Generic[D], abc.ABC):
     model_config = ConfigDict(extra="forbid")

     @classmethod
-    def __dispatch_key__(cls) -> str:
+    def __dispatch_key__(cls) -> Optional[str]:
        type_str = cls.model_fields["type"].default
        return type_str if isinstance(type_str, str) else None

@@ -119,19 +123,15 @@ class PickleSerializer(Serializer):
     - Wraps pickles in base64 for safe transmission.
     """

-    type: Literal["pickle"] = "pickle"
+    type: str = Field(default="pickle", frozen=True)

     picklelib: str = "cloudpickle"
     picklelib_version: Optional[str] = None

     @field_validator("picklelib")
-    def check_picklelib(cls, value):
+    def check_picklelib(cls, value: str) -> str:
         return validate_picklelib(value)

-    # @model_validator(mode="before")
-    # def check_picklelib_version(cls, values):
-    #     return validate_picklelib_version(values)
-
     def dumps(self, obj: Any) -> bytes:
         pickler = from_qualified_name(self.picklelib)
         blob = pickler.dumps(obj)
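Throughout this file, `Literal["..."]` type fields become plain `str` fields that are frozen, keeping the value immutable at runtime while loosening the static type, presumably so subclasses like `CompressedPickleSerializer` can override the default without violating the parent annotation. Under pydantic v2 semantics, assigning to a frozen field raises (a sketch):

    import pydantic
    from prefect.serializers import PickleSerializer

    s = PickleSerializer()
    try:
        s.type = "json"  # frozen=True: assignment is rejected
    except pydantic.ValidationError as exc:
        print(exc.errors()[0]["type"])  # frozen_field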
@@ -151,7 +151,7 @@ class JSONSerializer(Serializer):
     Wraps the `json` library to serialize to UTF-8 bytes instead of string types.
     """

-    type: Literal["json"] = "json"
+    type: str = Field(default="json", frozen=True)

    jsonlib: str = "json"
    object_encoder: Optional[str] = Field(
@@ -171,23 +171,27 @@
             "by our default `object_encoder`."
         ),
     )
-    dumps_kwargs: Dict[str, Any] = Field(default_factory=dict)
-    loads_kwargs: Dict[str, Any] = Field(default_factory=dict)
+    dumps_kwargs: dict[str, Any] = Field(default_factory=dict)
+    loads_kwargs: dict[str, Any] = Field(default_factory=dict)

     @field_validator("dumps_kwargs")
-    def dumps_kwargs_cannot_contain_default(cls, value):
+    def dumps_kwargs_cannot_contain_default(
+        cls, value: dict[str, Any]
+    ) -> dict[str, Any]:
         return validate_dump_kwargs(value)

     @field_validator("loads_kwargs")
-    def loads_kwargs_cannot_contain_object_hook(cls, value):
+    def loads_kwargs_cannot_contain_object_hook(
+        cls, value: dict[str, Any]
+    ) -> dict[str, Any]:
         return validate_load_kwargs(value)

-    def dumps(self, data: Any) -> bytes:
+    def dumps(self, obj: Any) -> bytes:
         json = from_qualified_name(self.jsonlib)
         kwargs = self.dumps_kwargs.copy()
         if self.object_encoder:
             kwargs["default"] = from_qualified_name(self.object_encoder)
-        result = json.dumps(data, **kwargs)
+        result = json.dumps(obj, **kwargs)
         if isinstance(result, str):
             # The standard library returns str but others may return bytes directly
             result = result.encode()
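The `data` → `obj` rename aligns `JSONSerializer.dumps` with the base `Serializer.dumps(obj)` signature, which matters for keyword callers and overrides. Behavior is unchanged; for instance:

    from prefect.serializers import JSONSerializer

    s = JSONSerializer(dumps_kwargs={"sort_keys": True})
    assert s.dumps({"b": 2, "a": 1}) == b'{"a": 1, "b": 2}'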
@@ -213,17 +217,17 @@ class CompressedSerializer(Serializer):
         level: If not null, the level of compression to pass to `compress`.
     """

-    type: Literal["compressed"] = "compressed"
+    type: str = Field(default="compressed", frozen=True)

     serializer: Serializer
     compressionlib: str = "lzma"

     @field_validator("serializer", mode="before")
-    def validate_serializer(cls, value):
+    def validate_serializer(cls, value: Union[str, Serializer]) -> Serializer:
         return cast_type_names_to_serializers(value)

     @field_validator("compressionlib")
-    def check_compressionlib(cls, value):
+    def check_compressionlib(cls, value: str) -> str:
         return validate_compressionlib(value)

     def dumps(self, obj: Any) -> bytes:
@@ -242,7 +246,7 @@ class CompressedPickleSerializer(CompressedSerializer):
     A compressed serializer preconfigured to use the pickle serializer.
     """

-    type: Literal["compressed/pickle"] = "compressed/pickle"
+    type: str = Field(default="compressed/pickle", frozen=True)

     serializer: Serializer = Field(default_factory=PickleSerializer)

@@ -252,6 +256,6 @@ class CompressedJSONSerializer(CompressedSerializer):
     A compressed serializer preconfigured to use the json serializer.
     """

-    type: Literal["compressed/json"] = "compressed/json"
+    type: str = Field(default="compressed/json", frozen=True)

     serializer: Serializer = Field(default_factory=JSONSerializer)
prefect/settings/models/experiments.py CHANGED
@@ -22,3 +22,8 @@ class ExperimentsSettings(PrefectBaseSettings):
         default=False,
         description="Enables sending telemetry to Prefect Cloud.",
     )
+
+    lineage_events_enabled: bool = Field(
+        default=False,
+        description="If `True`, enables emitting lineage events. Set to `False` to disable lineage event emission.",
+    )
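Lineage emission is opt-in. Following Prefect's settings naming conventions, the field should be reachable as sketched below (the environment variable name is inferred from the settings path, so verify it against your installed version):

    from prefect.settings.context import get_current_settings

    # Enable with: PREFECT_EXPERIMENTS_LINEAGE_EVENTS_ENABLED=true
    print(get_current_settings().experiments.lineage_events_enabled)  # False by default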
prefect/task_engine.py CHANGED
@@ -4,7 +4,7 @@ import logging
 import threading
 import time
 from asyncio import CancelledError
-from contextlib import ExitStack, asynccontextmanager, contextmanager
+from contextlib import ExitStack, asynccontextmanager, contextmanager, nullcontext
 from dataclasses import dataclass, field
 from functools import partial
 from textwrap import dedent
@@ -523,7 +523,7 @@ class SyncTaskRunEngine(BaseTaskRunEngine[P, R]):
         self.set_state(terminal_state)
         self._return_value = result

-        self._telemetry.end_span_on_success(terminal_state.message)
+        self._telemetry.end_span_on_success()
         return result

     def handle_retry(self, exc: Exception) -> bool:
@@ -586,7 +586,7 @@ class SyncTaskRunEngine(BaseTaskRunEngine[P, R]):
             self.record_terminal_state_timing(state)
             self.set_state(state)
             self._raised = exc
-            self._telemetry.end_span_on_failure(state.message)
+            self._telemetry.end_span_on_failure(state.message if state else None)

     def handle_timeout(self, exc: TimeoutError) -> None:
         if not self.handle_retry(exc):
@@ -600,6 +600,7 @@ class SyncTaskRunEngine(BaseTaskRunEngine[P, R]):
                 message=message,
                 name="TimedOut",
             )
+            self.record_terminal_state_timing(state)
             self.set_state(state)
             self._raised = exc

@@ -611,7 +612,7 @@ class SyncTaskRunEngine(BaseTaskRunEngine[P, R]):
         self.set_state(state, force=True)
         self._raised = exc
         self._telemetry.record_exception(exc)
-        self._telemetry.end_span_on_failure(state.message)
+        self._telemetry.end_span_on_failure(state.message if state else None)

     @contextmanager
     def setup_run_context(self, client: Optional[SyncPrefectClient] = None):
@@ -669,7 +670,7 @@ class SyncTaskRunEngine(BaseTaskRunEngine[P, R]):
         with SyncClientContext.get_or_create() as client_ctx:
             self._client = client_ctx.client
             self._is_started = True
-            flow_run_context = FlowRunContext.get()
+            parent_flow_run_context = FlowRunContext.get()
             parent_task_run_context = TaskRunContext.get()

             try:
@@ -678,7 +679,7 @@ class SyncTaskRunEngine(BaseTaskRunEngine[P, R]):
                     self.task.create_local_run(
                         id=task_run_id,
                         parameters=self.parameters,
-                        flow_run_context=flow_run_context,
+                        flow_run_context=parent_flow_run_context,
                         parent_task_run_context=parent_task_run_context,
                         wait_for=self.wait_for,
                         extra_task_inputs=dependencies,
@@ -696,11 +697,12 @@ class SyncTaskRunEngine(BaseTaskRunEngine[P, R]):
                 self.logger.debug(
                     f"Created task run {self.task_run.name!r} for task {self.task.name!r}"
                 )
-                labels = (
-                    flow_run_context.flow_run.labels if flow_run_context else {}
-                )
+
                 self._telemetry.start_span(
-                    self.task_run, self.parameters, labels
+                    run=self.task_run,
+                    name=self.task.name,
+                    client=self.client,
+                    parameters=self.parameters,
                 )

                 yield self
@@ -754,7 +756,9 @@ class SyncTaskRunEngine(BaseTaskRunEngine[P, R]):
         dependencies: Optional[Dict[str, Set[TaskRunInput]]] = None,
     ) -> Generator[None, None, None]:
         with self.initialize_run(task_run_id=task_run_id, dependencies=dependencies):
-            with trace.use_span(self._telemetry._span):
+            with trace.use_span(
+                self._telemetry.span
+            ) if self._telemetry.span else nullcontext():
                 self.begin_run()
                 try:
                     yield
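Both engines previously called `trace.use_span(self._telemetry._span)` unconditionally, which fails when no span was ever started (e.g. telemetry disabled). The conditional-with pattern used here, shown in isolation (a sketch; `opentelemetry` must be installed):

    from contextlib import nullcontext
    from opentelemetry import trace

    def run_body(span=None) -> None:
        # Enter the span context only when a span exists; no-op otherwise.
        with trace.use_span(span) if span else nullcontext():
            print("task body runs either way")

    run_body(None)  # telemetry disabled: takes the nullcontext branch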
@@ -1057,7 +1061,7 @@ class AsyncTaskRunEngine(BaseTaskRunEngine[P, R]):
         await self.set_state(terminal_state)
         self._return_value = result

-        self._telemetry.end_span_on_success(terminal_state.message)
+        self._telemetry.end_span_on_success()

         return result

@@ -1134,6 +1138,7 @@ class AsyncTaskRunEngine(BaseTaskRunEngine[P, R]):
                 message=message,
                 name="TimedOut",
             )
+            self.record_terminal_state_timing(state)
             await self.set_state(state)
             self._raised = exc
             self._telemetry.end_span_on_failure(state.message)
@@ -1204,15 +1209,16 @@ class AsyncTaskRunEngine(BaseTaskRunEngine[P, R]):
         async with AsyncClientContext.get_or_create():
             self._client = get_client()
             self._is_started = True
-            flow_run_context = FlowRunContext.get()
+            parent_flow_run_context = FlowRunContext.get()
+            parent_task_run_context = TaskRunContext.get()

             try:
                 if not self.task_run:
                     self.task_run = await self.task.create_local_run(
                         id=task_run_id,
                         parameters=self.parameters,
-                        flow_run_context=flow_run_context,
-                        parent_task_run_context=TaskRunContext.get(),
+                        flow_run_context=parent_flow_run_context,
+                        parent_task_run_context=parent_task_run_context,
                         wait_for=self.wait_for,
                         extra_task_inputs=dependencies,
                     )
@@ -1229,11 +1235,11 @@ class AsyncTaskRunEngine(BaseTaskRunEngine[P, R]):
                     f"Created task run {self.task_run.name!r} for task {self.task.name!r}"
                 )

-                labels = (
-                    flow_run_context.flow_run.labels if flow_run_context else {}
-                )
-                self._telemetry.start_span(
-                    self.task_run, self.parameters, labels
+                await self._telemetry.async_start_span(
+                    run=self.task_run,
+                    name=self.task.name,
+                    client=self.client,
+                    parameters=self.parameters,
                 )

                 yield self
@@ -1289,7 +1295,9 @@ class AsyncTaskRunEngine(BaseTaskRunEngine[P, R]):
         async with self.initialize_run(
             task_run_id=task_run_id, dependencies=dependencies
         ):
-            with trace.use_span(self._telemetry._span):
+            with trace.use_span(
+                self._telemetry.span
+            ) if self._telemetry.span else nullcontext():
                 await self.begin_run()
                 try:
                     yield
@@ -1370,7 +1378,7 @@ def run_task_sync(
     task_run_id: Optional[UUID] = None,
     task_run: Optional[TaskRun] = None,
     parameters: Optional[Dict[str, Any]] = None,
-    wait_for: Optional[Iterable[PrefectFuture]] = None,
+    wait_for: Optional[Iterable[PrefectFuture[R]]] = None,
     return_type: Literal["state", "result"] = "result",
     dependencies: Optional[Dict[str, Set[TaskRunInput]]] = None,
     context: Optional[Dict[str, Any]] = None,
@@ -1397,7 +1405,7 @@ async def run_task_async(
     task_run_id: Optional[UUID] = None,
     task_run: Optional[TaskRun] = None,
     parameters: Optional[Dict[str, Any]] = None,
-    wait_for: Optional[Iterable[PrefectFuture]] = None,
+    wait_for: Optional[Iterable[PrefectFuture[R]]] = None,
     return_type: Literal["state", "result"] = "result",
     dependencies: Optional[Dict[str, Set[TaskRunInput]]] = None,
     context: Optional[Dict[str, Any]] = None,
@@ -1424,7 +1432,7 @@ def run_generator_task_sync(
     task_run_id: Optional[UUID] = None,
     task_run: Optional[TaskRun] = None,
     parameters: Optional[Dict[str, Any]] = None,
-    wait_for: Optional[Iterable[PrefectFuture]] = None,
+    wait_for: Optional[Iterable[PrefectFuture[R]]] = None,
     return_type: Literal["state", "result"] = "result",
     dependencies: Optional[Dict[str, Set[TaskRunInput]]] = None,
     context: Optional[Dict[str, Any]] = None,
@@ -1479,7 +1487,7 @@ async def run_generator_task_async(
     task_run_id: Optional[UUID] = None,
     task_run: Optional[TaskRun] = None,
     parameters: Optional[Dict[str, Any]] = None,
-    wait_for: Optional[Iterable[PrefectFuture]] = None,
+    wait_for: Optional[Iterable[PrefectFuture[R]]] = None,
     return_type: Literal["state", "result"] = "result",
     dependencies: Optional[Dict[str, Set[TaskRunInput]]] = None,
     context: Optional[Dict[str, Any]] = None,
@@ -1535,7 +1543,7 @@ def run_task(
     task_run_id: Optional[UUID] = None,
     task_run: Optional[TaskRun] = None,
     parameters: Optional[Dict[str, Any]] = None,
-    wait_for: Optional[Iterable[PrefectFuture]] = None,
+    wait_for: Optional[Iterable[PrefectFuture[R]]] = None,
     return_type: Literal["state", "result"] = "result",
     dependencies: Optional[Dict[str, Set[TaskRunInput]]] = None,
     context: Optional[Dict[str, Any]] = None,
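These last hunks thread the task's return type through `wait_for`, typing it as `Iterable[PrefectFuture[R]]` instead of a bare `PrefectFuture`. This is the parameter used for explicit ordering dependencies between tasks, e.g.:

    from prefect import flow, task

    @task
    def extract() -> int:
        return 1

    @task
    def report() -> None:
        print("runs only after extract finishes")

    @flow
    def pipeline():
        f = extract.submit()         # f: PrefectFuture[int]
        report.submit(wait_for=[f])  # wait_for: Iterable[PrefectFuture]

    pipeline()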