prefect-client 3.1.4__py3-none-any.whl → 3.1.6__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (96):
  1. prefect/__init__.py +3 -0
  2. prefect/_internal/compatibility/migration.py +1 -1
  3. prefect/_internal/concurrency/api.py +52 -52
  4. prefect/_internal/concurrency/calls.py +59 -35
  5. prefect/_internal/concurrency/cancellation.py +34 -18
  6. prefect/_internal/concurrency/event_loop.py +7 -6
  7. prefect/_internal/concurrency/threads.py +41 -33
  8. prefect/_internal/concurrency/waiters.py +28 -21
  9. prefect/_internal/pydantic/v1_schema.py +2 -2
  10. prefect/_internal/pydantic/v2_schema.py +10 -9
  11. prefect/_internal/schemas/bases.py +10 -11
  12. prefect/_internal/schemas/validators.py +2 -1
  13. prefect/_version.py +3 -3
  14. prefect/automations.py +53 -47
  15. prefect/blocks/abstract.py +12 -10
  16. prefect/blocks/core.py +4 -2
  17. prefect/cache_policies.py +11 -11
  18. prefect/client/__init__.py +3 -1
  19. prefect/client/base.py +36 -37
  20. prefect/client/cloud.py +26 -19
  21. prefect/client/collections.py +2 -2
  22. prefect/client/orchestration.py +366 -277
  23. prefect/client/schemas/__init__.py +24 -0
  24. prefect/client/schemas/actions.py +132 -120
  25. prefect/client/schemas/filters.py +5 -0
  26. prefect/client/schemas/objects.py +113 -85
  27. prefect/client/schemas/responses.py +21 -18
  28. prefect/client/schemas/schedules.py +136 -93
  29. prefect/client/subscriptions.py +28 -14
  30. prefect/client/utilities.py +32 -36
  31. prefect/concurrency/asyncio.py +6 -9
  32. prefect/concurrency/services.py +3 -0
  33. prefect/concurrency/sync.py +35 -5
  34. prefect/context.py +39 -31
  35. prefect/deployments/flow_runs.py +3 -5
  36. prefect/docker/__init__.py +1 -1
  37. prefect/events/schemas/events.py +25 -20
  38. prefect/events/utilities.py +1 -2
  39. prefect/filesystems.py +3 -3
  40. prefect/flow_engine.py +755 -138
  41. prefect/flow_runs.py +3 -3
  42. prefect/flows.py +214 -170
  43. prefect/logging/configuration.py +1 -1
  44. prefect/logging/highlighters.py +1 -2
  45. prefect/logging/loggers.py +30 -20
  46. prefect/main.py +17 -24
  47. prefect/runner/runner.py +43 -21
  48. prefect/runner/server.py +30 -32
  49. prefect/runner/submit.py +3 -6
  50. prefect/runner/utils.py +6 -6
  51. prefect/runtime/flow_run.py +7 -0
  52. prefect/settings/constants.py +2 -2
  53. prefect/settings/legacy.py +1 -1
  54. prefect/settings/models/server/events.py +10 -0
  55. prefect/settings/sources.py +9 -2
  56. prefect/task_engine.py +72 -19
  57. prefect/task_runners.py +2 -2
  58. prefect/tasks.py +46 -33
  59. prefect/telemetry/bootstrap.py +15 -2
  60. prefect/telemetry/run_telemetry.py +107 -0
  61. prefect/transactions.py +14 -14
  62. prefect/types/__init__.py +20 -3
  63. prefect/utilities/_engine.py +96 -0
  64. prefect/utilities/annotations.py +25 -18
  65. prefect/utilities/asyncutils.py +126 -140
  66. prefect/utilities/callables.py +87 -78
  67. prefect/utilities/collections.py +278 -117
  68. prefect/utilities/compat.py +13 -21
  69. prefect/utilities/context.py +6 -5
  70. prefect/utilities/dispatch.py +23 -12
  71. prefect/utilities/dockerutils.py +33 -32
  72. prefect/utilities/engine.py +126 -239
  73. prefect/utilities/filesystem.py +18 -15
  74. prefect/utilities/hashing.py +10 -11
  75. prefect/utilities/importtools.py +40 -27
  76. prefect/utilities/math.py +9 -5
  77. prefect/utilities/names.py +3 -3
  78. prefect/utilities/processutils.py +121 -57
  79. prefect/utilities/pydantic.py +41 -36
  80. prefect/utilities/render_swagger.py +22 -12
  81. prefect/utilities/schema_tools/__init__.py +2 -1
  82. prefect/utilities/schema_tools/hydration.py +50 -43
  83. prefect/utilities/schema_tools/validation.py +52 -42
  84. prefect/utilities/services.py +13 -12
  85. prefect/utilities/templating.py +45 -45
  86. prefect/utilities/text.py +2 -1
  87. prefect/utilities/timeout.py +4 -4
  88. prefect/utilities/urls.py +9 -4
  89. prefect/utilities/visualization.py +46 -24
  90. prefect/variables.py +9 -8
  91. prefect/workers/base.py +18 -10
  92. {prefect_client-3.1.4.dist-info → prefect_client-3.1.6.dist-info}/METADATA +5 -5
  93. {prefect_client-3.1.4.dist-info → prefect_client-3.1.6.dist-info}/RECORD +96 -94
  94. {prefect_client-3.1.4.dist-info → prefect_client-3.1.6.dist-info}/WHEEL +1 -1
  95. {prefect_client-3.1.4.dist-info → prefect_client-3.1.6.dist-info}/LICENSE +0 -0
  96. {prefect_client-3.1.4.dist-info → prefect_client-3.1.6.dist-info}/top_level.txt +0 -0
@@ -1,5 +1,5 @@
1
1
  import datetime
2
- from typing import Any, Dict, List, Optional, TypeVar, Union
2
+ from typing import Any, ClassVar, Generic, Optional, TypeVar, Union
3
3
  from uuid import UUID
4
4
 
5
5
  from pydantic import ConfigDict, Field
@@ -9,10 +9,11 @@ from typing_extensions import Literal
9
9
  import prefect.client.schemas.objects as objects
10
10
  from prefect._internal.schemas.bases import ObjectBaseModel, PrefectBaseModel
11
11
  from prefect._internal.schemas.fields import CreatedBy, UpdatedBy
12
+ from prefect.types import KeyValueLabelsField
12
13
  from prefect.utilities.collections import AutoEnum
13
14
  from prefect.utilities.names import generate_slug
14
15
 
15
- R = TypeVar("R")
16
+ T = TypeVar("T")
16
17
 
17
18
 
18
19
  class SetStateStatus(AutoEnum):
@@ -119,7 +120,7 @@ class HistoryResponse(PrefectBaseModel):
119
120
  interval_end: DateTime = Field(
120
121
  default=..., description="The end date of the interval."
121
122
  )
122
- states: List[HistoryResponseState] = Field(
123
+ states: list[HistoryResponseState] = Field(
123
124
  default=..., description="A list of state histories during the interval."
124
125
  )
125
126
 
@@ -129,18 +130,18 @@ StateResponseDetails = Union[
129
130
  ]
130
131
 
131
132
 
132
- class OrchestrationResult(PrefectBaseModel):
133
+ class OrchestrationResult(PrefectBaseModel, Generic[T]):
133
134
  """
134
135
  A container for the output of state orchestration.
135
136
  """
136
137
 
137
- state: Optional[objects.State]
138
+ state: Optional[objects.State[T]]
138
139
  status: SetStateStatus
139
140
  details: StateResponseDetails
140
141
 
141
142
 
142
143
  class WorkerFlowRunResponse(PrefectBaseModel):
143
- model_config = ConfigDict(arbitrary_types_allowed=True)
144
+ model_config: ClassVar[ConfigDict] = ConfigDict(arbitrary_types_allowed=True)
144
145
 
145
146
  work_pool_id: UUID
146
147
  work_queue_id: UUID
@@ -178,7 +179,7 @@ class FlowRunResponse(ObjectBaseModel):
178
179
  description="The version of the flow executed in this flow run.",
179
180
  examples=["1.0"],
180
181
  )
181
- parameters: Dict[str, Any] = Field(
182
+ parameters: dict[str, Any] = Field(
182
183
  default_factory=dict, description="Parameters for the flow run."
183
184
  )
184
185
  idempotency_key: Optional[str] = Field(
@@ -188,7 +189,7 @@ class FlowRunResponse(ObjectBaseModel):
188
189
  " run is not created multiple times."
189
190
  ),
190
191
  )
191
- context: Dict[str, Any] = Field(
192
+ context: dict[str, Any] = Field(
192
193
  default_factory=dict,
193
194
  description="Additional context for the flow run.",
194
195
  examples=[{"my_var": "my_val"}],
@@ -196,11 +197,12 @@ class FlowRunResponse(ObjectBaseModel):
196
197
  empirical_policy: objects.FlowRunPolicy = Field(
197
198
  default_factory=objects.FlowRunPolicy,
198
199
  )
199
- tags: List[str] = Field(
200
+ tags: list[str] = Field(
200
201
  default_factory=list,
201
202
  description="A list of tags on the flow run",
202
203
  examples=[["tag-1", "tag-2"]],
203
204
  )
205
+ labels: KeyValueLabelsField
204
206
  parent_task_run_id: Optional[UUID] = Field(
205
207
  default=None,
206
208
  description=(
@@ -273,7 +275,7 @@ class FlowRunResponse(ObjectBaseModel):
273
275
  description="The state of the flow run.",
274
276
  examples=["objects.State(type=objects.StateType.COMPLETED)"],
275
277
  )
276
- job_variables: Optional[dict] = Field(
278
+ job_variables: Optional[dict[str, Any]] = Field(
277
279
  default=None, description="Job variables for the flow run."
278
280
  )
279
281
 
@@ -333,26 +335,27 @@ class DeploymentResponse(ObjectBaseModel):
333
335
  default=None,
334
336
  description="The concurrency options for the deployment.",
335
337
  )
336
- schedules: List[objects.DeploymentSchedule] = Field(
338
+ schedules: list[objects.DeploymentSchedule] = Field(
337
339
  default_factory=list, description="A list of schedules for the deployment."
338
340
  )
339
- job_variables: Dict[str, Any] = Field(
341
+ job_variables: dict[str, Any] = Field(
340
342
  default_factory=dict,
341
343
  description="Overrides to apply to flow run infrastructure at runtime.",
342
344
  )
343
- parameters: Dict[str, Any] = Field(
345
+ parameters: dict[str, Any] = Field(
344
346
  default_factory=dict,
345
347
  description="Parameters for flow runs scheduled by the deployment.",
346
348
  )
347
- pull_steps: Optional[List[dict]] = Field(
349
+ pull_steps: Optional[list[dict[str, Any]]] = Field(
348
350
  default=None,
349
351
  description="Pull steps for cloning and running this deployment.",
350
352
  )
351
- tags: List[str] = Field(
353
+ tags: list[str] = Field(
352
354
  default_factory=list,
353
355
  description="A list of tags for the deployment",
354
356
  examples=[["tag-1", "tag-2"]],
355
357
  )
358
+ labels: KeyValueLabelsField
356
359
  work_queue_name: Optional[str] = Field(
357
360
  default=None,
358
361
  description=(
@@ -364,7 +367,7 @@ class DeploymentResponse(ObjectBaseModel):
364
367
  default=None,
365
368
  description="The last time the deployment was polled for status updates.",
366
369
  )
367
- parameter_openapi_schema: Optional[Dict[str, Any]] = Field(
370
+ parameter_openapi_schema: Optional[dict[str, Any]] = Field(
368
371
  default=None,
369
372
  description="The parameter schema of the flow, including defaults.",
370
373
  )
@@ -397,7 +400,7 @@ class DeploymentResponse(ObjectBaseModel):
397
400
  default=None,
398
401
  description="Optional information about the updater of this deployment.",
399
402
  )
400
- work_queue_id: UUID = Field(
403
+ work_queue_id: Optional[UUID] = Field(
401
404
  default=None,
402
405
  description=(
403
406
  "The id of the work pool queue to which this deployment is assigned."
@@ -420,7 +423,7 @@ class DeploymentResponse(ObjectBaseModel):
420
423
 
421
424
 
422
425
  class MinimalConcurrencyLimitResponse(PrefectBaseModel):
423
- model_config = ConfigDict(extra="ignore")
426
+ model_config: ClassVar[ConfigDict] = ConfigDict(extra="ignore")
424
427
 
425
428
  id: UUID
426
429
  name: str
@@ -3,13 +3,13 @@ Schedule schemas
3
3
  """
4
4
 
5
5
  import datetime
6
- from typing import Annotated, Any, Optional, Union
6
+ from typing import TYPE_CHECKING, Annotated, Any, ClassVar, Optional, Union
7
7
 
8
8
  import dateutil
9
9
  import dateutil.rrule
10
+ import dateutil.tz
10
11
  import pendulum
11
12
  from pydantic import AfterValidator, ConfigDict, Field, field_validator, model_validator
12
- from pydantic_extra_types.pendulum_dt import DateTime
13
13
  from typing_extensions import TypeAlias, TypeGuard
14
14
 
15
15
  from prefect._internal.schemas.bases import PrefectBaseModel
@@ -20,6 +20,14 @@ from prefect._internal.schemas.validators import (
20
20
  validate_rrule_string,
21
21
  )
22
22
 
23
+ if TYPE_CHECKING:
24
+ # type checkers have difficulty accepting that
25
+ # pydantic_extra_types.pendulum_dt and pendulum.DateTime can be used
26
+ # together.
27
+ DateTime = pendulum.DateTime
28
+ else:
29
+ from pydantic_extra_types.pendulum_dt import DateTime
30
+
23
31
  MAX_ITERATIONS = 1000
24
32
  # approx. 1 years worth of RDATEs + buffer
25
33
  MAX_RRULE_LENGTH = 6500
@@ -54,7 +62,7 @@ class IntervalSchedule(PrefectBaseModel):
54
62
  timezone (str, optional): a valid timezone string
55
63
  """
56
64
 
57
- model_config = ConfigDict(extra="forbid", exclude_none=True)
65
+ model_config: ClassVar[ConfigDict] = ConfigDict(extra="forbid")
58
66
 
59
67
  interval: datetime.timedelta = Field(gt=datetime.timedelta(0))
60
68
  anchor_date: Annotated[DateTime, AfterValidator(default_anchor_date)] = Field(
@@ -68,6 +76,19 @@ class IntervalSchedule(PrefectBaseModel):
68
76
  self.timezone = default_timezone(self.timezone, self.model_dump())
69
77
  return self
70
78
 
79
+ if TYPE_CHECKING:
80
+ # The model accepts str or datetime values for `anchor_date`
81
+ def __init__(
82
+ self,
83
+ /,
84
+ interval: datetime.timedelta,
85
+ anchor_date: Optional[
86
+ Union[pendulum.DateTime, datetime.datetime, str]
87
+ ] = None,
88
+ timezone: Optional[str] = None,
89
+ ) -> None:
90
+ ...
91
+
71
92
 
72
93
  class CronSchedule(PrefectBaseModel):
73
94
  """
@@ -94,7 +115,7 @@ class CronSchedule(PrefectBaseModel):
94
115
 
95
116
  """
96
117
 
97
- model_config = ConfigDict(extra="forbid")
118
+ model_config: ClassVar[ConfigDict] = ConfigDict(extra="forbid")
98
119
 
99
120
  cron: str = Field(default=..., examples=["0 0 * * *"])
100
121
  timezone: Optional[str] = Field(default=None, examples=["America/New_York"])
@@ -107,18 +128,36 @@ class CronSchedule(PrefectBaseModel):
107
128
 
108
129
  @field_validator("timezone")
109
130
  @classmethod
110
- def valid_timezone(cls, v):
131
+ def valid_timezone(cls, v: Optional[str]) -> str:
111
132
  return default_timezone(v)
112
133
 
113
134
  @field_validator("cron")
114
135
  @classmethod
115
- def valid_cron_string(cls, v):
136
+ def valid_cron_string(cls, v: str) -> str:
116
137
  return validate_cron_string(v)
117
138
 
118
139
 
119
140
  DEFAULT_ANCHOR_DATE = pendulum.date(2020, 1, 1)
120
141
 
121
142
 
143
+ def _rrule_dt(
144
+ rrule: dateutil.rrule.rrule, name: str = "_dtstart"
145
+ ) -> Optional[datetime.datetime]:
146
+ return getattr(rrule, name, None)
147
+
148
+
149
+ def _rrule(
150
+ rruleset: dateutil.rrule.rruleset, name: str = "_rrule"
151
+ ) -> list[dateutil.rrule.rrule]:
152
+ return getattr(rruleset, name, [])
153
+
154
+
155
+ def _rdates(
156
+ rrule: dateutil.rrule.rruleset, name: str = "_rdate"
157
+ ) -> list[datetime.datetime]:
158
+ return getattr(rrule, name, [])
159
+
160
+
122
161
  class RRuleSchedule(PrefectBaseModel):
123
162
  """
124
163
  RRule schedule, based on the iCalendar standard
@@ -139,7 +178,7 @@ class RRuleSchedule(PrefectBaseModel):
139
178
  timezone (str, optional): a valid timezone string
140
179
  """
141
180
 
142
- model_config = ConfigDict(extra="forbid")
181
+ model_config: ClassVar[ConfigDict] = ConfigDict(extra="forbid")
143
182
 
144
183
  rrule: str
145
184
  timezone: Optional[str] = Field(
@@ -148,58 +187,60 @@ class RRuleSchedule(PrefectBaseModel):
148
187
 
149
188
  @field_validator("rrule")
150
189
  @classmethod
151
- def validate_rrule_str(cls, v):
190
+ def validate_rrule_str(cls, v: str) -> str:
152
191
  return validate_rrule_string(v)
153
192
 
154
193
  @classmethod
155
- def from_rrule(cls, rrule: dateutil.rrule.rrule):
194
+ def from_rrule(
195
+ cls, rrule: Union[dateutil.rrule.rrule, dateutil.rrule.rruleset]
196
+ ) -> "RRuleSchedule":
156
197
  if isinstance(rrule, dateutil.rrule.rrule):
157
- if rrule._dtstart.tzinfo is not None:
158
- timezone = rrule._dtstart.tzinfo.name
198
+ dtstart = _rrule_dt(rrule)
199
+ if dtstart and dtstart.tzinfo is not None:
200
+ timezone = dtstart.tzinfo.tzname(dtstart)
159
201
  else:
160
202
  timezone = "UTC"
161
203
  return RRuleSchedule(rrule=str(rrule), timezone=timezone)
162
- elif isinstance(rrule, dateutil.rrule.rruleset):
163
- dtstarts = [rr._dtstart for rr in rrule._rrule if rr._dtstart is not None]
164
- unique_dstarts = set(pendulum.instance(d).in_tz("UTC") for d in dtstarts)
165
- unique_timezones = set(d.tzinfo for d in dtstarts if d.tzinfo is not None)
166
-
167
- if len(unique_timezones) > 1:
168
- raise ValueError(
169
- f"rruleset has too many dtstart timezones: {unique_timezones}"
170
- )
171
-
172
- if len(unique_dstarts) > 1:
173
- raise ValueError(f"rruleset has too many dtstarts: {unique_dstarts}")
174
-
175
- if unique_dstarts and unique_timezones:
176
- timezone = dtstarts[0].tzinfo.name
177
- else:
178
- timezone = "UTC"
179
-
180
- rruleset_string = ""
181
- if rrule._rrule:
182
- rruleset_string += "\n".join(str(r) for r in rrule._rrule)
183
- if rrule._exrule:
184
- rruleset_string += "\n" if rruleset_string else ""
185
- rruleset_string += "\n".join(str(r) for r in rrule._exrule).replace(
186
- "RRULE", "EXRULE"
187
- )
188
- if rrule._rdate:
189
- rruleset_string += "\n" if rruleset_string else ""
190
- rruleset_string += "RDATE:" + ",".join(
191
- rd.strftime("%Y%m%dT%H%M%SZ") for rd in rrule._rdate
192
- )
193
- if rrule._exdate:
194
- rruleset_string += "\n" if rruleset_string else ""
195
- rruleset_string += "EXDATE:" + ",".join(
196
- exd.strftime("%Y%m%dT%H%M%SZ") for exd in rrule._exdate
197
- )
198
- return RRuleSchedule(rrule=rruleset_string, timezone=timezone)
204
+ rrules = _rrule(rrule)
205
+ dtstarts = [dts for rr in rrules if (dts := _rrule_dt(rr)) is not None]
206
+ unique_dstarts = set(pendulum.instance(d).in_tz("UTC") for d in dtstarts)
207
+ unique_timezones = set(d.tzinfo for d in dtstarts if d.tzinfo is not None)
208
+
209
+ if len(unique_timezones) > 1:
210
+ raise ValueError(
211
+ f"rruleset has too many dtstart timezones: {unique_timezones}"
212
+ )
213
+
214
+ if len(unique_dstarts) > 1:
215
+ raise ValueError(f"rruleset has too many dtstarts: {unique_dstarts}")
216
+
217
+ if unique_dstarts and unique_timezones:
218
+ [unique_tz] = unique_timezones
219
+ timezone = unique_tz.tzname(dtstarts[0])
199
220
  else:
200
- raise ValueError(f"Invalid RRule object: {rrule}")
201
-
202
- def to_rrule(self) -> dateutil.rrule.rrule:
221
+ timezone = "UTC"
222
+
223
+ rruleset_string = ""
224
+ if rrules:
225
+ rruleset_string += "\n".join(str(r) for r in rrules)
226
+ if exrule := _rrule(rrule, "_exrule"):
227
+ rruleset_string += "\n" if rruleset_string else ""
228
+ rruleset_string += "\n".join(str(r) for r in exrule).replace(
229
+ "RRULE", "EXRULE"
230
+ )
231
+ if rdates := _rdates(rrule):
232
+ rruleset_string += "\n" if rruleset_string else ""
233
+ rruleset_string += "RDATE:" + ",".join(
234
+ rd.strftime("%Y%m%dT%H%M%SZ") for rd in rdates
235
+ )
236
+ if exdates := _rdates(rrule, "_exdate"):
237
+ rruleset_string += "\n" if rruleset_string else ""
238
+ rruleset_string += "EXDATE:" + ",".join(
239
+ exd.strftime("%Y%m%dT%H%M%SZ") for exd in exdates
240
+ )
241
+ return RRuleSchedule(rrule=rruleset_string, timezone=timezone)
242
+
243
+ def to_rrule(self) -> Union[dateutil.rrule.rrule, dateutil.rrule.rruleset]:
203
244
  """
204
245
  Since rrule doesn't properly serialize/deserialize timezones, we localize dates
205
246
  here
@@ -211,51 +252,53 @@ class RRuleSchedule(PrefectBaseModel):
211
252
  )
212
253
  timezone = dateutil.tz.gettz(self.timezone)
213
254
  if isinstance(rrule, dateutil.rrule.rrule):
214
- kwargs = dict(dtstart=rrule._dtstart.replace(tzinfo=timezone))
215
- if rrule._until:
255
+ dtstart = _rrule_dt(rrule)
256
+ assert dtstart is not None
257
+ kwargs: dict[str, Any] = dict(dtstart=dtstart.replace(tzinfo=timezone))
258
+ if until := _rrule_dt(rrule, "_until"):
216
259
  kwargs.update(
217
- until=rrule._until.replace(tzinfo=timezone),
260
+ until=until.replace(tzinfo=timezone),
218
261
  )
219
262
  return rrule.replace(**kwargs)
220
- elif isinstance(rrule, dateutil.rrule.rruleset):
221
- # update rrules
222
- localized_rrules = []
223
- for rr in rrule._rrule:
224
- kwargs = dict(dtstart=rr._dtstart.replace(tzinfo=timezone))
225
- if rr._until:
226
- kwargs.update(
227
- until=rr._until.replace(tzinfo=timezone),
228
- )
229
- localized_rrules.append(rr.replace(**kwargs))
230
- rrule._rrule = localized_rrules
231
-
232
- # update exrules
233
- localized_exrules = []
234
- for exr in rrule._exrule:
235
- kwargs = dict(dtstart=exr._dtstart.replace(tzinfo=timezone))
236
- if exr._until:
237
- kwargs.update(
238
- until=exr._until.replace(tzinfo=timezone),
239
- )
240
- localized_exrules.append(exr.replace(**kwargs))
241
- rrule._exrule = localized_exrules
242
-
243
- # update rdates
244
- localized_rdates = []
245
- for rd in rrule._rdate:
246
- localized_rdates.append(rd.replace(tzinfo=timezone))
247
- rrule._rdate = localized_rdates
248
-
249
- # update exdates
250
- localized_exdates = []
251
- for exd in rrule._exdate:
252
- localized_exdates.append(exd.replace(tzinfo=timezone))
253
- rrule._exdate = localized_exdates
254
-
255
- return rrule
263
+
264
+ # update rrules
265
+ localized_rrules: list[dateutil.rrule.rrule] = []
266
+ for rr in _rrule(rrule):
267
+ dtstart = _rrule_dt(rr)
268
+ assert dtstart is not None
269
+ kwargs: dict[str, Any] = dict(dtstart=dtstart.replace(tzinfo=timezone))
270
+ if until := _rrule_dt(rr, "_until"):
271
+ kwargs.update(until=until.replace(tzinfo=timezone))
272
+ localized_rrules.append(rr.replace(**kwargs))
273
+ setattr(rrule, "_rrule", localized_rrules)
274
+
275
+ # update exrules
276
+ localized_exrules: list[dateutil.rrule.rruleset] = []
277
+ for exr in _rrule(rrule, "_exrule"):
278
+ dtstart = _rrule_dt(exr)
279
+ assert dtstart is not None
280
+ kwargs = dict(dtstart=dtstart.replace(tzinfo=timezone))
281
+ if until := _rrule_dt(exr, "_until"):
282
+ kwargs.update(until=until.replace(tzinfo=timezone))
283
+ localized_exrules.append(exr.replace(**kwargs))
284
+ setattr(rrule, "_exrule", localized_exrules)
285
+
286
+ # update rdates
287
+ localized_rdates: list[datetime.datetime] = []
288
+ for rd in _rdates(rrule):
289
+ localized_rdates.append(rd.replace(tzinfo=timezone))
290
+ setattr(rrule, "_rdate", localized_rdates)
291
+
292
+ # update exdates
293
+ localized_exdates: list[datetime.datetime] = []
294
+ for exd in _rdates(rrule, "_exdate"):
295
+ localized_exdates.append(exd.replace(tzinfo=timezone))
296
+ setattr(rrule, "_exdate", localized_exdates)
297
+
298
+ return rrule
256
299
 
257
300
  @field_validator("timezone")
258
- def valid_timezone(cls, v):
301
+ def valid_timezone(cls, v: Optional[str]) -> str:
259
302
  """
260
303
  Validate that the provided timezone is a valid IANA timezone.
261
304
 
@@ -277,7 +320,7 @@ class RRuleSchedule(PrefectBaseModel):
277
320
 
278
321
 
279
322
  class NoSchedule(PrefectBaseModel):
280
- model_config = ConfigDict(extra="forbid")
323
+ model_config: ClassVar[ConfigDict] = ConfigDict(extra="forbid")
281
324
 
282
325
 
283
326
  SCHEDULE_TYPES: TypeAlias = Union[
@@ -326,7 +369,7 @@ def construct_schedule(
326
369
  if isinstance(interval, (int, float)):
327
370
  interval = datetime.timedelta(seconds=interval)
328
371
  if not anchor_date:
329
- anchor_date = DateTime.now()
372
+ anchor_date = pendulum.DateTime.now()
330
373
  schedule = IntervalSchedule(
331
374
  interval=interval, anchor_date=anchor_date, timezone=timezone
332
375
  )
@@ -1,5 +1,7 @@
1
1
  import asyncio
2
- from typing import Any, Dict, Generic, Iterable, Optional, Type, TypeVar
2
+ from collections.abc import Iterable
3
+ from logging import Logger
4
+ from typing import Any, Generic, Optional, TypeVar
3
5
 
4
6
  import orjson
5
7
  import websockets
@@ -11,7 +13,7 @@ from prefect._internal.schemas.bases import IDBaseModel
11
13
  from prefect.logging import get_logger
12
14
  from prefect.settings import PREFECT_API_KEY
13
15
 
14
- logger = get_logger(__name__)
16
+ logger: Logger = get_logger(__name__)
15
17
 
16
18
  S = TypeVar("S", bound=IDBaseModel)
17
19
 
@@ -19,7 +21,7 @@ S = TypeVar("S", bound=IDBaseModel)
19
21
  class Subscription(Generic[S]):
20
22
  def __init__(
21
23
  self,
22
- model: Type[S],
24
+ model: type[S],
23
25
  path: str,
24
26
  keys: Iterable[str],
25
27
  client_id: Optional[str] = None,
@@ -27,27 +29,33 @@ class Subscription(Generic[S]):
27
29
  ):
28
30
  self.model = model
29
31
  self.client_id = client_id
30
- base_url = base_url.replace("http", "ws", 1)
31
- self.subscription_url = f"{base_url}{path}"
32
+ base_url = base_url.replace("http", "ws", 1) if base_url else None
33
+ self.subscription_url: str = f"{base_url}{path}"
32
34
 
33
- self.keys = list(keys)
35
+ self.keys: list[str] = list(keys)
34
36
 
35
37
  self._connect = websockets.connect(
36
38
  self.subscription_url,
37
- subprotocols=["prefect"],
39
+ subprotocols=[websockets.Subprotocol("prefect")],
38
40
  )
39
41
  self._websocket = None
40
42
 
41
43
  def __aiter__(self) -> Self:
42
44
  return self
43
45
 
46
+ @property
47
+ def websocket(self) -> websockets.WebSocketClientProtocol:
48
+ if not self._websocket:
49
+ raise RuntimeError("Subscription is not connected")
50
+ return self._websocket
51
+
44
52
  async def __anext__(self) -> S:
45
53
  while True:
46
54
  try:
47
55
  await self._ensure_connected()
48
- message = await self._websocket.recv()
56
+ message = await self.websocket.recv()
49
57
 
50
- await self._websocket.send(orjson.dumps({"type": "ack"}).decode())
58
+ await self.websocket.send(orjson.dumps({"type": "ack"}).decode())
51
59
 
52
60
  return self.model.model_validate_json(message)
53
61
  except (
@@ -72,10 +80,10 @@ class Subscription(Generic[S]):
72
80
  ).decode()
73
81
  )
74
82
 
75
- auth: Dict[str, Any] = orjson.loads(await websocket.recv())
83
+ auth: dict[str, Any] = orjson.loads(await websocket.recv())
76
84
  assert auth["type"] == "auth_success", auth.get("message")
77
85
 
78
- message = {"type": "subscribe", "keys": self.keys}
86
+ message: dict[str, Any] = {"type": "subscribe", "keys": self.keys}
79
87
  if self.client_id:
80
88
  message.update({"client_id": self.client_id})
81
89
 
@@ -84,13 +92,19 @@ class Subscription(Generic[S]):
84
92
  AssertionError,
85
93
  websockets.exceptions.ConnectionClosedError,
86
94
  ) as e:
87
- if isinstance(e, AssertionError) or e.rcvd.code == WS_1008_POLICY_VIOLATION:
95
+ if isinstance(e, AssertionError) or (
96
+ e.rcvd and e.rcvd.code == WS_1008_POLICY_VIOLATION
97
+ ):
88
98
  if isinstance(e, AssertionError):
89
99
  reason = e.args[0]
90
- elif isinstance(e, websockets.exceptions.ConnectionClosedError):
100
+ elif e.rcvd and e.rcvd.reason:
91
101
  reason = e.rcvd.reason
102
+ else:
103
+ reason = "unknown"
104
+ else:
105
+ reason = None
92
106
 
93
- if isinstance(e, AssertionError) or e.rcvd.code == WS_1008_POLICY_VIOLATION:
107
+ if reason:
94
108
  raise Exception(
95
109
  "Unable to authenticate to the subscription. Please "
96
110
  "ensure the provided `PREFECT_API_KEY` you are using is "