prefect-client 3.1.5__py3-none-any.whl → 3.1.6__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (93)
  1. prefect/__init__.py +3 -0
  2. prefect/_internal/compatibility/migration.py +1 -1
  3. prefect/_internal/concurrency/api.py +52 -52
  4. prefect/_internal/concurrency/calls.py +59 -35
  5. prefect/_internal/concurrency/cancellation.py +34 -18
  6. prefect/_internal/concurrency/event_loop.py +7 -6
  7. prefect/_internal/concurrency/threads.py +41 -33
  8. prefect/_internal/concurrency/waiters.py +28 -21
  9. prefect/_internal/pydantic/v1_schema.py +2 -2
  10. prefect/_internal/pydantic/v2_schema.py +10 -9
  11. prefect/_internal/schemas/bases.py +9 -7
  12. prefect/_internal/schemas/validators.py +2 -1
  13. prefect/_version.py +3 -3
  14. prefect/automations.py +53 -47
  15. prefect/blocks/abstract.py +12 -10
  16. prefect/blocks/core.py +4 -2
  17. prefect/cache_policies.py +11 -11
  18. prefect/client/__init__.py +3 -1
  19. prefect/client/base.py +36 -37
  20. prefect/client/cloud.py +26 -19
  21. prefect/client/collections.py +2 -2
  22. prefect/client/orchestration.py +342 -273
  23. prefect/client/schemas/__init__.py +24 -0
  24. prefect/client/schemas/actions.py +123 -116
  25. prefect/client/schemas/objects.py +110 -81
  26. prefect/client/schemas/responses.py +18 -18
  27. prefect/client/schemas/schedules.py +136 -93
  28. prefect/client/subscriptions.py +28 -14
  29. prefect/client/utilities.py +32 -36
  30. prefect/concurrency/asyncio.py +6 -9
  31. prefect/concurrency/sync.py +35 -5
  32. prefect/context.py +39 -31
  33. prefect/deployments/flow_runs.py +3 -5
  34. prefect/docker/__init__.py +1 -1
  35. prefect/events/schemas/events.py +25 -20
  36. prefect/events/utilities.py +1 -2
  37. prefect/filesystems.py +3 -3
  38. prefect/flow_engine.py +61 -21
  39. prefect/flow_runs.py +3 -3
  40. prefect/flows.py +214 -170
  41. prefect/logging/configuration.py +1 -1
  42. prefect/logging/highlighters.py +1 -2
  43. prefect/logging/loggers.py +30 -20
  44. prefect/main.py +17 -24
  45. prefect/runner/runner.py +43 -21
  46. prefect/runner/server.py +30 -32
  47. prefect/runner/submit.py +3 -6
  48. prefect/runner/utils.py +6 -6
  49. prefect/runtime/flow_run.py +7 -0
  50. prefect/settings/constants.py +2 -2
  51. prefect/settings/legacy.py +1 -1
  52. prefect/settings/models/server/events.py +10 -0
  53. prefect/task_engine.py +72 -19
  54. prefect/task_runners.py +2 -2
  55. prefect/tasks.py +46 -33
  56. prefect/telemetry/bootstrap.py +15 -2
  57. prefect/telemetry/run_telemetry.py +107 -0
  58. prefect/transactions.py +14 -14
  59. prefect/types/__init__.py +1 -4
  60. prefect/utilities/_engine.py +96 -0
  61. prefect/utilities/annotations.py +25 -18
  62. prefect/utilities/asyncutils.py +126 -140
  63. prefect/utilities/callables.py +87 -78
  64. prefect/utilities/collections.py +278 -117
  65. prefect/utilities/compat.py +13 -21
  66. prefect/utilities/context.py +6 -5
  67. prefect/utilities/dispatch.py +23 -12
  68. prefect/utilities/dockerutils.py +33 -32
  69. prefect/utilities/engine.py +126 -239
  70. prefect/utilities/filesystem.py +18 -15
  71. prefect/utilities/hashing.py +10 -11
  72. prefect/utilities/importtools.py +40 -27
  73. prefect/utilities/math.py +9 -5
  74. prefect/utilities/names.py +3 -3
  75. prefect/utilities/processutils.py +121 -57
  76. prefect/utilities/pydantic.py +41 -36
  77. prefect/utilities/render_swagger.py +22 -12
  78. prefect/utilities/schema_tools/__init__.py +2 -1
  79. prefect/utilities/schema_tools/hydration.py +50 -43
  80. prefect/utilities/schema_tools/validation.py +52 -42
  81. prefect/utilities/services.py +13 -12
  82. prefect/utilities/templating.py +45 -45
  83. prefect/utilities/text.py +2 -1
  84. prefect/utilities/timeout.py +4 -4
  85. prefect/utilities/urls.py +9 -4
  86. prefect/utilities/visualization.py +46 -24
  87. prefect/variables.py +9 -8
  88. prefect/workers/base.py +15 -8
  89. {prefect_client-3.1.5.dist-info → prefect_client-3.1.6.dist-info}/METADATA +4 -2
  90. {prefect_client-3.1.5.dist-info → prefect_client-3.1.6.dist-info}/RECORD +93 -91
  91. {prefect_client-3.1.5.dist-info → prefect_client-3.1.6.dist-info}/LICENSE +0 -0
  92. {prefect_client-3.1.5.dist-info → prefect_client-3.1.6.dist-info}/WHEEL +0 -0
  93. {prefect_client-3.1.5.dist-info → prefect_client-3.1.6.dist-info}/top_level.txt +0 -0
@@ -1,5 +1,5 @@
1
1
  import datetime
2
- from typing import Any, Dict, List, Optional, TypeVar, Union
2
+ from typing import Any, ClassVar, Generic, Optional, TypeVar, Union
3
3
  from uuid import UUID
4
4
 
5
5
  from pydantic import ConfigDict, Field
@@ -13,7 +13,7 @@ from prefect.types import KeyValueLabelsField
13
13
  from prefect.utilities.collections import AutoEnum
14
14
  from prefect.utilities.names import generate_slug
15
15
 
16
- R = TypeVar("R")
16
+ T = TypeVar("T")
17
17
 
18
18
 
19
19
  class SetStateStatus(AutoEnum):
@@ -120,7 +120,7 @@ class HistoryResponse(PrefectBaseModel):
120
120
  interval_end: DateTime = Field(
121
121
  default=..., description="The end date of the interval."
122
122
  )
123
- states: List[HistoryResponseState] = Field(
123
+ states: list[HistoryResponseState] = Field(
124
124
  default=..., description="A list of state histories during the interval."
125
125
  )
126
126
 
@@ -130,18 +130,18 @@ StateResponseDetails = Union[
130
130
  ]
131
131
 
132
132
 
133
- class OrchestrationResult(PrefectBaseModel):
133
+ class OrchestrationResult(PrefectBaseModel, Generic[T]):
134
134
  """
135
135
  A container for the output of state orchestration.
136
136
  """
137
137
 
138
- state: Optional[objects.State]
138
+ state: Optional[objects.State[T]]
139
139
  status: SetStateStatus
140
140
  details: StateResponseDetails
141
141
 
142
142
 
143
143
  class WorkerFlowRunResponse(PrefectBaseModel):
144
- model_config = ConfigDict(arbitrary_types_allowed=True)
144
+ model_config: ClassVar[ConfigDict] = ConfigDict(arbitrary_types_allowed=True)
145
145
 
146
146
  work_pool_id: UUID
147
147
  work_queue_id: UUID
@@ -179,7 +179,7 @@ class FlowRunResponse(ObjectBaseModel):
179
179
  description="The version of the flow executed in this flow run.",
180
180
  examples=["1.0"],
181
181
  )
182
- parameters: Dict[str, Any] = Field(
182
+ parameters: dict[str, Any] = Field(
183
183
  default_factory=dict, description="Parameters for the flow run."
184
184
  )
185
185
  idempotency_key: Optional[str] = Field(
@@ -189,7 +189,7 @@ class FlowRunResponse(ObjectBaseModel):
189
189
  " run is not created multiple times."
190
190
  ),
191
191
  )
192
- context: Dict[str, Any] = Field(
192
+ context: dict[str, Any] = Field(
193
193
  default_factory=dict,
194
194
  description="Additional context for the flow run.",
195
195
  examples=[{"my_var": "my_val"}],
@@ -197,7 +197,7 @@ class FlowRunResponse(ObjectBaseModel):
197
197
  empirical_policy: objects.FlowRunPolicy = Field(
198
198
  default_factory=objects.FlowRunPolicy,
199
199
  )
200
- tags: List[str] = Field(
200
+ tags: list[str] = Field(
201
201
  default_factory=list,
202
202
  description="A list of tags on the flow run",
203
203
  examples=[["tag-1", "tag-2"]],
@@ -275,7 +275,7 @@ class FlowRunResponse(ObjectBaseModel):
275
275
  description="The state of the flow run.",
276
276
  examples=["objects.State(type=objects.StateType.COMPLETED)"],
277
277
  )
278
- job_variables: Optional[dict] = Field(
278
+ job_variables: Optional[dict[str, Any]] = Field(
279
279
  default=None, description="Job variables for the flow run."
280
280
  )
281
281
 
@@ -335,22 +335,22 @@ class DeploymentResponse(ObjectBaseModel):
335
335
  default=None,
336
336
  description="The concurrency options for the deployment.",
337
337
  )
338
- schedules: List[objects.DeploymentSchedule] = Field(
338
+ schedules: list[objects.DeploymentSchedule] = Field(
339
339
  default_factory=list, description="A list of schedules for the deployment."
340
340
  )
341
- job_variables: Dict[str, Any] = Field(
341
+ job_variables: dict[str, Any] = Field(
342
342
  default_factory=dict,
343
343
  description="Overrides to apply to flow run infrastructure at runtime.",
344
344
  )
345
- parameters: Dict[str, Any] = Field(
345
+ parameters: dict[str, Any] = Field(
346
346
  default_factory=dict,
347
347
  description="Parameters for flow runs scheduled by the deployment.",
348
348
  )
349
- pull_steps: Optional[List[dict]] = Field(
349
+ pull_steps: Optional[list[dict[str, Any]]] = Field(
350
350
  default=None,
351
351
  description="Pull steps for cloning and running this deployment.",
352
352
  )
353
- tags: List[str] = Field(
353
+ tags: list[str] = Field(
354
354
  default_factory=list,
355
355
  description="A list of tags for the deployment",
356
356
  examples=[["tag-1", "tag-2"]],
@@ -367,7 +367,7 @@ class DeploymentResponse(ObjectBaseModel):
367
367
  default=None,
368
368
  description="The last time the deployment was polled for status updates.",
369
369
  )
370
- parameter_openapi_schema: Optional[Dict[str, Any]] = Field(
370
+ parameter_openapi_schema: Optional[dict[str, Any]] = Field(
371
371
  default=None,
372
372
  description="The parameter schema of the flow, including defaults.",
373
373
  )
@@ -400,7 +400,7 @@ class DeploymentResponse(ObjectBaseModel):
400
400
  default=None,
401
401
  description="Optional information about the updater of this deployment.",
402
402
  )
403
- work_queue_id: UUID = Field(
403
+ work_queue_id: Optional[UUID] = Field(
404
404
  default=None,
405
405
  description=(
406
406
  "The id of the work pool queue to which this deployment is assigned."
@@ -423,7 +423,7 @@ class DeploymentResponse(ObjectBaseModel):
423
423
 
424
424
 
425
425
  class MinimalConcurrencyLimitResponse(PrefectBaseModel):
426
- model_config = ConfigDict(extra="ignore")
426
+ model_config: ClassVar[ConfigDict] = ConfigDict(extra="ignore")
427
427
 
428
428
  id: UUID
429
429
  name: str
@@ -3,13 +3,13 @@ Schedule schemas
3
3
  """
4
4
 
5
5
  import datetime
6
- from typing import Annotated, Any, Optional, Union
6
+ from typing import TYPE_CHECKING, Annotated, Any, ClassVar, Optional, Union
7
7
 
8
8
  import dateutil
9
9
  import dateutil.rrule
10
+ import dateutil.tz
10
11
  import pendulum
11
12
  from pydantic import AfterValidator, ConfigDict, Field, field_validator, model_validator
12
- from pydantic_extra_types.pendulum_dt import DateTime
13
13
  from typing_extensions import TypeAlias, TypeGuard
14
14
 
15
15
  from prefect._internal.schemas.bases import PrefectBaseModel
@@ -20,6 +20,14 @@ from prefect._internal.schemas.validators import (
20
20
  validate_rrule_string,
21
21
  )
22
22
 
23
+ if TYPE_CHECKING:
24
+ # type checkers have difficulty accepting that
25
+ # pydantic_extra_types.pendulum_dt and pendulum.DateTime can be used
26
+ # together.
27
+ DateTime = pendulum.DateTime
28
+ else:
29
+ from pydantic_extra_types.pendulum_dt import DateTime
30
+
23
31
  MAX_ITERATIONS = 1000
24
32
  # approx. 1 years worth of RDATEs + buffer
25
33
  MAX_RRULE_LENGTH = 6500
@@ -54,7 +62,7 @@ class IntervalSchedule(PrefectBaseModel):
54
62
  timezone (str, optional): a valid timezone string
55
63
  """
56
64
 
57
- model_config = ConfigDict(extra="forbid", exclude_none=True)
65
+ model_config: ClassVar[ConfigDict] = ConfigDict(extra="forbid")
58
66
 
59
67
  interval: datetime.timedelta = Field(gt=datetime.timedelta(0))
60
68
  anchor_date: Annotated[DateTime, AfterValidator(default_anchor_date)] = Field(
@@ -68,6 +76,19 @@ class IntervalSchedule(PrefectBaseModel):
68
76
  self.timezone = default_timezone(self.timezone, self.model_dump())
69
77
  return self
70
78
 
79
+ if TYPE_CHECKING:
80
+ # The model accepts str or datetime values for `anchor_date`
81
+ def __init__(
82
+ self,
83
+ /,
84
+ interval: datetime.timedelta,
85
+ anchor_date: Optional[
86
+ Union[pendulum.DateTime, datetime.datetime, str]
87
+ ] = None,
88
+ timezone: Optional[str] = None,
89
+ ) -> None:
90
+ ...
91
+
71
92
 
72
93
  class CronSchedule(PrefectBaseModel):
73
94
  """
@@ -94,7 +115,7 @@ class CronSchedule(PrefectBaseModel):
94
115
 
95
116
  """
96
117
 
97
- model_config = ConfigDict(extra="forbid")
118
+ model_config: ClassVar[ConfigDict] = ConfigDict(extra="forbid")
98
119
 
99
120
  cron: str = Field(default=..., examples=["0 0 * * *"])
100
121
  timezone: Optional[str] = Field(default=None, examples=["America/New_York"])
@@ -107,18 +128,36 @@ class CronSchedule(PrefectBaseModel):
107
128
 
108
129
  @field_validator("timezone")
109
130
  @classmethod
110
- def valid_timezone(cls, v):
131
+ def valid_timezone(cls, v: Optional[str]) -> str:
111
132
  return default_timezone(v)
112
133
 
113
134
  @field_validator("cron")
114
135
  @classmethod
115
- def valid_cron_string(cls, v):
136
+ def valid_cron_string(cls, v: str) -> str:
116
137
  return validate_cron_string(v)
117
138
 
118
139
 
119
140
  DEFAULT_ANCHOR_DATE = pendulum.date(2020, 1, 1)
120
141
 
121
142
 
143
+ def _rrule_dt(
144
+ rrule: dateutil.rrule.rrule, name: str = "_dtstart"
145
+ ) -> Optional[datetime.datetime]:
146
+ return getattr(rrule, name, None)
147
+
148
+
149
+ def _rrule(
150
+ rruleset: dateutil.rrule.rruleset, name: str = "_rrule"
151
+ ) -> list[dateutil.rrule.rrule]:
152
+ return getattr(rruleset, name, [])
153
+
154
+
155
+ def _rdates(
156
+ rrule: dateutil.rrule.rruleset, name: str = "_rdate"
157
+ ) -> list[datetime.datetime]:
158
+ return getattr(rrule, name, [])
159
+
160
+
122
161
  class RRuleSchedule(PrefectBaseModel):
123
162
  """
124
163
  RRule schedule, based on the iCalendar standard
@@ -139,7 +178,7 @@ class RRuleSchedule(PrefectBaseModel):
139
178
  timezone (str, optional): a valid timezone string
140
179
  """
141
180
 
142
- model_config = ConfigDict(extra="forbid")
181
+ model_config: ClassVar[ConfigDict] = ConfigDict(extra="forbid")
143
182
 
144
183
  rrule: str
145
184
  timezone: Optional[str] = Field(
@@ -148,58 +187,60 @@ class RRuleSchedule(PrefectBaseModel):
148
187
 
149
188
  @field_validator("rrule")
150
189
  @classmethod
151
- def validate_rrule_str(cls, v):
190
+ def validate_rrule_str(cls, v: str) -> str:
152
191
  return validate_rrule_string(v)
153
192
 
154
193
  @classmethod
155
- def from_rrule(cls, rrule: dateutil.rrule.rrule):
194
+ def from_rrule(
195
+ cls, rrule: Union[dateutil.rrule.rrule, dateutil.rrule.rruleset]
196
+ ) -> "RRuleSchedule":
156
197
  if isinstance(rrule, dateutil.rrule.rrule):
157
- if rrule._dtstart.tzinfo is not None:
158
- timezone = rrule._dtstart.tzinfo.name
198
+ dtstart = _rrule_dt(rrule)
199
+ if dtstart and dtstart.tzinfo is not None:
200
+ timezone = dtstart.tzinfo.tzname(dtstart)
159
201
  else:
160
202
  timezone = "UTC"
161
203
  return RRuleSchedule(rrule=str(rrule), timezone=timezone)
162
- elif isinstance(rrule, dateutil.rrule.rruleset):
163
- dtstarts = [rr._dtstart for rr in rrule._rrule if rr._dtstart is not None]
164
- unique_dstarts = set(pendulum.instance(d).in_tz("UTC") for d in dtstarts)
165
- unique_timezones = set(d.tzinfo for d in dtstarts if d.tzinfo is not None)
166
-
167
- if len(unique_timezones) > 1:
168
- raise ValueError(
169
- f"rruleset has too many dtstart timezones: {unique_timezones}"
170
- )
171
-
172
- if len(unique_dstarts) > 1:
173
- raise ValueError(f"rruleset has too many dtstarts: {unique_dstarts}")
174
-
175
- if unique_dstarts and unique_timezones:
176
- timezone = dtstarts[0].tzinfo.name
177
- else:
178
- timezone = "UTC"
179
-
180
- rruleset_string = ""
181
- if rrule._rrule:
182
- rruleset_string += "\n".join(str(r) for r in rrule._rrule)
183
- if rrule._exrule:
184
- rruleset_string += "\n" if rruleset_string else ""
185
- rruleset_string += "\n".join(str(r) for r in rrule._exrule).replace(
186
- "RRULE", "EXRULE"
187
- )
188
- if rrule._rdate:
189
- rruleset_string += "\n" if rruleset_string else ""
190
- rruleset_string += "RDATE:" + ",".join(
191
- rd.strftime("%Y%m%dT%H%M%SZ") for rd in rrule._rdate
192
- )
193
- if rrule._exdate:
194
- rruleset_string += "\n" if rruleset_string else ""
195
- rruleset_string += "EXDATE:" + ",".join(
196
- exd.strftime("%Y%m%dT%H%M%SZ") for exd in rrule._exdate
197
- )
198
- return RRuleSchedule(rrule=rruleset_string, timezone=timezone)
204
+ rrules = _rrule(rrule)
205
+ dtstarts = [dts for rr in rrules if (dts := _rrule_dt(rr)) is not None]
206
+ unique_dstarts = set(pendulum.instance(d).in_tz("UTC") for d in dtstarts)
207
+ unique_timezones = set(d.tzinfo for d in dtstarts if d.tzinfo is not None)
208
+
209
+ if len(unique_timezones) > 1:
210
+ raise ValueError(
211
+ f"rruleset has too many dtstart timezones: {unique_timezones}"
212
+ )
213
+
214
+ if len(unique_dstarts) > 1:
215
+ raise ValueError(f"rruleset has too many dtstarts: {unique_dstarts}")
216
+
217
+ if unique_dstarts and unique_timezones:
218
+ [unique_tz] = unique_timezones
219
+ timezone = unique_tz.tzname(dtstarts[0])
199
220
  else:
200
- raise ValueError(f"Invalid RRule object: {rrule}")
201
-
202
- def to_rrule(self) -> dateutil.rrule.rrule:
221
+ timezone = "UTC"
222
+
223
+ rruleset_string = ""
224
+ if rrules:
225
+ rruleset_string += "\n".join(str(r) for r in rrules)
226
+ if exrule := _rrule(rrule, "_exrule"):
227
+ rruleset_string += "\n" if rruleset_string else ""
228
+ rruleset_string += "\n".join(str(r) for r in exrule).replace(
229
+ "RRULE", "EXRULE"
230
+ )
231
+ if rdates := _rdates(rrule):
232
+ rruleset_string += "\n" if rruleset_string else ""
233
+ rruleset_string += "RDATE:" + ",".join(
234
+ rd.strftime("%Y%m%dT%H%M%SZ") for rd in rdates
235
+ )
236
+ if exdates := _rdates(rrule, "_exdate"):
237
+ rruleset_string += "\n" if rruleset_string else ""
238
+ rruleset_string += "EXDATE:" + ",".join(
239
+ exd.strftime("%Y%m%dT%H%M%SZ") for exd in exdates
240
+ )
241
+ return RRuleSchedule(rrule=rruleset_string, timezone=timezone)
242
+
243
+ def to_rrule(self) -> Union[dateutil.rrule.rrule, dateutil.rrule.rruleset]:
203
244
  """
204
245
  Since rrule doesn't properly serialize/deserialize timezones, we localize dates
205
246
  here
@@ -211,51 +252,53 @@ class RRuleSchedule(PrefectBaseModel):
211
252
  )
212
253
  timezone = dateutil.tz.gettz(self.timezone)
213
254
  if isinstance(rrule, dateutil.rrule.rrule):
214
- kwargs = dict(dtstart=rrule._dtstart.replace(tzinfo=timezone))
215
- if rrule._until:
255
+ dtstart = _rrule_dt(rrule)
256
+ assert dtstart is not None
257
+ kwargs: dict[str, Any] = dict(dtstart=dtstart.replace(tzinfo=timezone))
258
+ if until := _rrule_dt(rrule, "_until"):
216
259
  kwargs.update(
217
- until=rrule._until.replace(tzinfo=timezone),
260
+ until=until.replace(tzinfo=timezone),
218
261
  )
219
262
  return rrule.replace(**kwargs)
220
- elif isinstance(rrule, dateutil.rrule.rruleset):
221
- # update rrules
222
- localized_rrules = []
223
- for rr in rrule._rrule:
224
- kwargs = dict(dtstart=rr._dtstart.replace(tzinfo=timezone))
225
- if rr._until:
226
- kwargs.update(
227
- until=rr._until.replace(tzinfo=timezone),
228
- )
229
- localized_rrules.append(rr.replace(**kwargs))
230
- rrule._rrule = localized_rrules
231
-
232
- # update exrules
233
- localized_exrules = []
234
- for exr in rrule._exrule:
235
- kwargs = dict(dtstart=exr._dtstart.replace(tzinfo=timezone))
236
- if exr._until:
237
- kwargs.update(
238
- until=exr._until.replace(tzinfo=timezone),
239
- )
240
- localized_exrules.append(exr.replace(**kwargs))
241
- rrule._exrule = localized_exrules
242
-
243
- # update rdates
244
- localized_rdates = []
245
- for rd in rrule._rdate:
246
- localized_rdates.append(rd.replace(tzinfo=timezone))
247
- rrule._rdate = localized_rdates
248
-
249
- # update exdates
250
- localized_exdates = []
251
- for exd in rrule._exdate:
252
- localized_exdates.append(exd.replace(tzinfo=timezone))
253
- rrule._exdate = localized_exdates
254
-
255
- return rrule
263
+
264
+ # update rrules
265
+ localized_rrules: list[dateutil.rrule.rrule] = []
266
+ for rr in _rrule(rrule):
267
+ dtstart = _rrule_dt(rr)
268
+ assert dtstart is not None
269
+ kwargs: dict[str, Any] = dict(dtstart=dtstart.replace(tzinfo=timezone))
270
+ if until := _rrule_dt(rr, "_until"):
271
+ kwargs.update(until=until.replace(tzinfo=timezone))
272
+ localized_rrules.append(rr.replace(**kwargs))
273
+ setattr(rrule, "_rrule", localized_rrules)
274
+
275
+ # update exrules
276
+ localized_exrules: list[dateutil.rrule.rruleset] = []
277
+ for exr in _rrule(rrule, "_exrule"):
278
+ dtstart = _rrule_dt(exr)
279
+ assert dtstart is not None
280
+ kwargs = dict(dtstart=dtstart.replace(tzinfo=timezone))
281
+ if until := _rrule_dt(exr, "_until"):
282
+ kwargs.update(until=until.replace(tzinfo=timezone))
283
+ localized_exrules.append(exr.replace(**kwargs))
284
+ setattr(rrule, "_exrule", localized_exrules)
285
+
286
+ # update rdates
287
+ localized_rdates: list[datetime.datetime] = []
288
+ for rd in _rdates(rrule):
289
+ localized_rdates.append(rd.replace(tzinfo=timezone))
290
+ setattr(rrule, "_rdate", localized_rdates)
291
+
292
+ # update exdates
293
+ localized_exdates: list[datetime.datetime] = []
294
+ for exd in _rdates(rrule, "_exdate"):
295
+ localized_exdates.append(exd.replace(tzinfo=timezone))
296
+ setattr(rrule, "_exdate", localized_exdates)
297
+
298
+ return rrule
256
299
 
257
300
  @field_validator("timezone")
258
- def valid_timezone(cls, v):
301
+ def valid_timezone(cls, v: Optional[str]) -> str:
259
302
  """
260
303
  Validate that the provided timezone is a valid IANA timezone.
261
304
 
@@ -277,7 +320,7 @@ class RRuleSchedule(PrefectBaseModel):
277
320
 
278
321
 
279
322
  class NoSchedule(PrefectBaseModel):
280
- model_config = ConfigDict(extra="forbid")
323
+ model_config: ClassVar[ConfigDict] = ConfigDict(extra="forbid")
281
324
 
282
325
 
283
326
  SCHEDULE_TYPES: TypeAlias = Union[
@@ -326,7 +369,7 @@ def construct_schedule(
326
369
  if isinstance(interval, (int, float)):
327
370
  interval = datetime.timedelta(seconds=interval)
328
371
  if not anchor_date:
329
- anchor_date = DateTime.now()
372
+ anchor_date = pendulum.DateTime.now()
330
373
  schedule = IntervalSchedule(
331
374
  interval=interval, anchor_date=anchor_date, timezone=timezone
332
375
  )
@@ -1,5 +1,7 @@
1
1
  import asyncio
2
- from typing import Any, Dict, Generic, Iterable, Optional, Type, TypeVar
2
+ from collections.abc import Iterable
3
+ from logging import Logger
4
+ from typing import Any, Generic, Optional, TypeVar
3
5
 
4
6
  import orjson
5
7
  import websockets
@@ -11,7 +13,7 @@ from prefect._internal.schemas.bases import IDBaseModel
11
13
  from prefect.logging import get_logger
12
14
  from prefect.settings import PREFECT_API_KEY
13
15
 
14
- logger = get_logger(__name__)
16
+ logger: Logger = get_logger(__name__)
15
17
 
16
18
  S = TypeVar("S", bound=IDBaseModel)
17
19
 
@@ -19,7 +21,7 @@ S = TypeVar("S", bound=IDBaseModel)
19
21
  class Subscription(Generic[S]):
20
22
  def __init__(
21
23
  self,
22
- model: Type[S],
24
+ model: type[S],
23
25
  path: str,
24
26
  keys: Iterable[str],
25
27
  client_id: Optional[str] = None,
@@ -27,27 +29,33 @@ class Subscription(Generic[S]):
27
29
  ):
28
30
  self.model = model
29
31
  self.client_id = client_id
30
- base_url = base_url.replace("http", "ws", 1)
31
- self.subscription_url = f"{base_url}{path}"
32
+ base_url = base_url.replace("http", "ws", 1) if base_url else None
33
+ self.subscription_url: str = f"{base_url}{path}"
32
34
 
33
- self.keys = list(keys)
35
+ self.keys: list[str] = list(keys)
34
36
 
35
37
  self._connect = websockets.connect(
36
38
  self.subscription_url,
37
- subprotocols=["prefect"],
39
+ subprotocols=[websockets.Subprotocol("prefect")],
38
40
  )
39
41
  self._websocket = None
40
42
 
41
43
  def __aiter__(self) -> Self:
42
44
  return self
43
45
 
46
+ @property
47
+ def websocket(self) -> websockets.WebSocketClientProtocol:
48
+ if not self._websocket:
49
+ raise RuntimeError("Subscription is not connected")
50
+ return self._websocket
51
+
44
52
  async def __anext__(self) -> S:
45
53
  while True:
46
54
  try:
47
55
  await self._ensure_connected()
48
- message = await self._websocket.recv()
56
+ message = await self.websocket.recv()
49
57
 
50
- await self._websocket.send(orjson.dumps({"type": "ack"}).decode())
58
+ await self.websocket.send(orjson.dumps({"type": "ack"}).decode())
51
59
 
52
60
  return self.model.model_validate_json(message)
53
61
  except (
@@ -72,10 +80,10 @@ class Subscription(Generic[S]):
72
80
  ).decode()
73
81
  )
74
82
 
75
- auth: Dict[str, Any] = orjson.loads(await websocket.recv())
83
+ auth: dict[str, Any] = orjson.loads(await websocket.recv())
76
84
  assert auth["type"] == "auth_success", auth.get("message")
77
85
 
78
- message = {"type": "subscribe", "keys": self.keys}
86
+ message: dict[str, Any] = {"type": "subscribe", "keys": self.keys}
79
87
  if self.client_id:
80
88
  message.update({"client_id": self.client_id})
81
89
 
@@ -84,13 +92,19 @@ class Subscription(Generic[S]):
84
92
  AssertionError,
85
93
  websockets.exceptions.ConnectionClosedError,
86
94
  ) as e:
87
- if isinstance(e, AssertionError) or e.rcvd.code == WS_1008_POLICY_VIOLATION:
95
+ if isinstance(e, AssertionError) or (
96
+ e.rcvd and e.rcvd.code == WS_1008_POLICY_VIOLATION
97
+ ):
88
98
  if isinstance(e, AssertionError):
89
99
  reason = e.args[0]
90
- elif isinstance(e, websockets.exceptions.ConnectionClosedError):
100
+ elif e.rcvd and e.rcvd.reason:
91
101
  reason = e.rcvd.reason
102
+ else:
103
+ reason = "unknown"
104
+ else:
105
+ reason = None
92
106
 
93
- if isinstance(e, AssertionError) or e.rcvd.code == WS_1008_POLICY_VIOLATION:
107
+ if reason:
94
108
  raise Exception(
95
109
  "Unable to authenticate to the subscription. Please "
96
110
  "ensure the provided `PREFECT_API_KEY` you are using is "
@@ -5,31 +5,32 @@ Utilities for working with clients.
5
5
  # This module must not import from `prefect.client` when it is imported to avoid
6
6
  # circular imports for decorators such as `inject_client` which are widely used.
7
7
 
8
+ from collections.abc import Awaitable, Coroutine
8
9
  from functools import wraps
9
- from typing import (
10
- TYPE_CHECKING,
11
- Any,
12
- Awaitable,
13
- Callable,
14
- Coroutine,
15
- Optional,
16
- Tuple,
17
- TypeVar,
18
- cast,
19
- )
20
-
21
- from typing_extensions import Concatenate, ParamSpec
10
+ from typing import TYPE_CHECKING, Any, Callable, Optional, Union
11
+
12
+ from typing_extensions import Concatenate, ParamSpec, TypeGuard, TypeVar
22
13
 
23
14
  if TYPE_CHECKING:
24
- from prefect.client.orchestration import PrefectClient
15
+ from prefect.client.orchestration import PrefectClient, SyncPrefectClient
25
16
 
26
17
  P = ParamSpec("P")
27
- R = TypeVar("R")
18
+ R = TypeVar("R", infer_variance=True)
19
+
20
+
21
+ def _current_async_client(
22
+ client: Union["PrefectClient", "SyncPrefectClient"],
23
+ ) -> TypeGuard["PrefectClient"]:
24
+ """Determine if the client is a PrefectClient instance attached to the current loop"""
25
+ from prefect._internal.concurrency.event_loop import get_running_loop
26
+
27
+ # Only a PrefectClient will have a _loop attribute that is the current loop
28
+ return getattr(client, "_loop", None) == get_running_loop()
28
29
 
29
30
 
30
31
  def get_or_create_client(
31
32
  client: Optional["PrefectClient"] = None,
32
- ) -> Tuple["PrefectClient", bool]:
33
+ ) -> tuple["PrefectClient", bool]:
33
34
  """
34
35
  Returns provided client, infers a client from context if available, or creates a new client.
35
36
 
@@ -41,29 +42,22 @@ def get_or_create_client(
41
42
  """
42
43
  if client is not None:
43
44
  return client, True
44
- from prefect._internal.concurrency.event_loop import get_running_loop
45
+
45
46
  from prefect.context import AsyncClientContext, FlowRunContext, TaskRunContext
46
47
 
47
48
  async_client_context = AsyncClientContext.get()
48
49
  flow_run_context = FlowRunContext.get()
49
50
  task_run_context = TaskRunContext.get()
50
51
 
51
- if async_client_context and async_client_context.client._loop == get_running_loop():
52
- return async_client_context.client, True
53
- elif (
54
- flow_run_context
55
- and getattr(flow_run_context.client, "_loop", None) == get_running_loop()
56
- ):
57
- return flow_run_context.client, True
58
- elif (
59
- task_run_context
60
- and getattr(task_run_context.client, "_loop", None) == get_running_loop()
61
- ):
62
- return task_run_context.client, True
63
- else:
64
- from prefect.client.orchestration import get_client as get_httpx_client
52
+ for context in (async_client_context, flow_run_context, task_run_context):
53
+ if context is None:
54
+ continue
55
+ if _current_async_client(context_client := context.client):
56
+ return context_client, True
57
+
58
+ from prefect.client.orchestration import get_client as get_httpx_client
65
59
 
66
- return get_httpx_client(), False
60
+ return get_httpx_client(), False
67
61
 
68
62
 
69
63
  def client_injector(
@@ -90,16 +84,18 @@ def inject_client(
90
84
 
91
85
  @wraps(fn)
92
86
  async def with_injected_client(*args: P.args, **kwargs: P.kwargs) -> R:
93
- client = cast(Optional["PrefectClient"], kwargs.pop("client", None))
94
- client, inferred = get_or_create_client(client)
87
+ given = kwargs.pop("client", None)
88
+ if TYPE_CHECKING:
89
+ assert given is None or isinstance(given, PrefectClient)
90
+ client, inferred = get_or_create_client(given)
95
91
  if not inferred:
96
92
  context = client
97
93
  else:
98
94
  from prefect.utilities.asyncutils import asyncnullcontext
99
95
 
100
- context = asyncnullcontext()
96
+ context = asyncnullcontext(client)
101
97
  async with context as new_client:
102
- kwargs.setdefault("client", new_client or client)
98
+ kwargs |= {"client": new_client}
103
99
  return await fn(*args, **kwargs)
104
100
 
105
101
  return with_injected_client