prefect-client 2.20.2__py3-none-any.whl → 3.0.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (288)
  1. prefect/__init__.py +74 -110
  2. prefect/_internal/compatibility/deprecated.py +6 -115
  3. prefect/_internal/compatibility/experimental.py +4 -79
  4. prefect/_internal/compatibility/migration.py +166 -0
  5. prefect/_internal/concurrency/__init__.py +2 -2
  6. prefect/_internal/concurrency/api.py +1 -35
  7. prefect/_internal/concurrency/calls.py +0 -6
  8. prefect/_internal/concurrency/cancellation.py +0 -3
  9. prefect/_internal/concurrency/event_loop.py +0 -20
  10. prefect/_internal/concurrency/inspection.py +3 -3
  11. prefect/_internal/concurrency/primitives.py +1 -0
  12. prefect/_internal/concurrency/services.py +23 -0
  13. prefect/_internal/concurrency/threads.py +35 -0
  14. prefect/_internal/concurrency/waiters.py +0 -28
  15. prefect/_internal/integrations.py +7 -0
  16. prefect/_internal/pydantic/__init__.py +0 -45
  17. prefect/_internal/pydantic/annotations/pendulum.py +2 -2
  18. prefect/_internal/pydantic/v1_schema.py +21 -22
  19. prefect/_internal/pydantic/v2_schema.py +0 -2
  20. prefect/_internal/pydantic/v2_validated_func.py +18 -23
  21. prefect/_internal/pytz.py +1 -1
  22. prefect/_internal/retries.py +61 -0
  23. prefect/_internal/schemas/bases.py +45 -177
  24. prefect/_internal/schemas/fields.py +1 -43
  25. prefect/_internal/schemas/validators.py +47 -233
  26. prefect/agent.py +3 -695
  27. prefect/artifacts.py +173 -14
  28. prefect/automations.py +39 -4
  29. prefect/blocks/abstract.py +1 -1
  30. prefect/blocks/core.py +423 -164
  31. prefect/blocks/fields.py +2 -57
  32. prefect/blocks/notifications.py +43 -28
  33. prefect/blocks/redis.py +168 -0
  34. prefect/blocks/system.py +67 -20
  35. prefect/blocks/webhook.py +2 -9
  36. prefect/cache_policies.py +239 -0
  37. prefect/client/__init__.py +4 -0
  38. prefect/client/base.py +33 -27
  39. prefect/client/cloud.py +65 -20
  40. prefect/client/collections.py +1 -1
  41. prefect/client/orchestration.py +667 -440
  42. prefect/client/schemas/actions.py +115 -100
  43. prefect/client/schemas/filters.py +46 -52
  44. prefect/client/schemas/objects.py +228 -178
  45. prefect/client/schemas/responses.py +18 -36
  46. prefect/client/schemas/schedules.py +55 -36
  47. prefect/client/schemas/sorting.py +2 -0
  48. prefect/client/subscriptions.py +8 -7
  49. prefect/client/types/flexible_schedule_list.py +11 -0
  50. prefect/client/utilities.py +9 -6
  51. prefect/concurrency/asyncio.py +60 -11
  52. prefect/concurrency/context.py +24 -0
  53. prefect/concurrency/events.py +2 -2
  54. prefect/concurrency/services.py +46 -16
  55. prefect/concurrency/sync.py +51 -7
  56. prefect/concurrency/v1/asyncio.py +143 -0
  57. prefect/concurrency/v1/context.py +27 -0
  58. prefect/concurrency/v1/events.py +61 -0
  59. prefect/concurrency/v1/services.py +116 -0
  60. prefect/concurrency/v1/sync.py +92 -0
  61. prefect/context.py +246 -149
  62. prefect/deployments/__init__.py +33 -18
  63. prefect/deployments/base.py +10 -15
  64. prefect/deployments/deployments.py +2 -1048
  65. prefect/deployments/flow_runs.py +178 -0
  66. prefect/deployments/runner.py +72 -173
  67. prefect/deployments/schedules.py +31 -25
  68. prefect/deployments/steps/__init__.py +0 -1
  69. prefect/deployments/steps/core.py +7 -0
  70. prefect/deployments/steps/pull.py +15 -21
  71. prefect/deployments/steps/utility.py +2 -1
  72. prefect/docker/__init__.py +20 -0
  73. prefect/docker/docker_image.py +82 -0
  74. prefect/engine.py +15 -2466
  75. prefect/events/actions.py +17 -23
  76. prefect/events/cli/automations.py +20 -7
  77. prefect/events/clients.py +142 -80
  78. prefect/events/filters.py +14 -18
  79. prefect/events/related.py +74 -75
  80. prefect/events/schemas/__init__.py +0 -5
  81. prefect/events/schemas/automations.py +55 -46
  82. prefect/events/schemas/deployment_triggers.py +7 -197
  83. prefect/events/schemas/events.py +46 -65
  84. prefect/events/schemas/labelling.py +10 -14
  85. prefect/events/utilities.py +4 -5
  86. prefect/events/worker.py +23 -8
  87. prefect/exceptions.py +15 -0
  88. prefect/filesystems.py +30 -529
  89. prefect/flow_engine.py +827 -0
  90. prefect/flow_runs.py +379 -7
  91. prefect/flows.py +470 -360
  92. prefect/futures.py +382 -331
  93. prefect/infrastructure/__init__.py +5 -26
  94. prefect/infrastructure/base.py +3 -320
  95. prefect/infrastructure/provisioners/__init__.py +5 -3
  96. prefect/infrastructure/provisioners/cloud_run.py +13 -8
  97. prefect/infrastructure/provisioners/container_instance.py +14 -9
  98. prefect/infrastructure/provisioners/ecs.py +10 -8
  99. prefect/infrastructure/provisioners/modal.py +8 -5
  100. prefect/input/__init__.py +4 -0
  101. prefect/input/actions.py +2 -4
  102. prefect/input/run_input.py +9 -9
  103. prefect/logging/formatters.py +2 -4
  104. prefect/logging/handlers.py +9 -14
  105. prefect/logging/loggers.py +5 -5
  106. prefect/main.py +72 -0
  107. prefect/plugins.py +2 -64
  108. prefect/profiles.toml +16 -2
  109. prefect/records/__init__.py +1 -0
  110. prefect/records/base.py +223 -0
  111. prefect/records/filesystem.py +207 -0
  112. prefect/records/memory.py +178 -0
  113. prefect/records/result_store.py +64 -0
  114. prefect/results.py +577 -504
  115. prefect/runner/runner.py +124 -51
  116. prefect/runner/server.py +32 -34
  117. prefect/runner/storage.py +3 -12
  118. prefect/runner/submit.py +2 -10
  119. prefect/runner/utils.py +2 -2
  120. prefect/runtime/__init__.py +1 -0
  121. prefect/runtime/deployment.py +1 -0
  122. prefect/runtime/flow_run.py +40 -5
  123. prefect/runtime/task_run.py +1 -0
  124. prefect/serializers.py +28 -39
  125. prefect/server/api/collections_data/views/aggregate-worker-metadata.json +5 -14
  126. prefect/settings.py +209 -332
  127. prefect/states.py +160 -63
  128. prefect/task_engine.py +1478 -57
  129. prefect/task_runners.py +383 -287
  130. prefect/task_runs.py +240 -0
  131. prefect/task_worker.py +463 -0
  132. prefect/tasks.py +684 -374
  133. prefect/transactions.py +410 -0
  134. prefect/types/__init__.py +72 -86
  135. prefect/types/entrypoint.py +13 -0
  136. prefect/utilities/annotations.py +4 -3
  137. prefect/utilities/asyncutils.py +227 -148
  138. prefect/utilities/callables.py +138 -48
  139. prefect/utilities/collections.py +134 -86
  140. prefect/utilities/dispatch.py +27 -14
  141. prefect/utilities/dockerutils.py +11 -4
  142. prefect/utilities/engine.py +186 -32
  143. prefect/utilities/filesystem.py +4 -5
  144. prefect/utilities/importtools.py +26 -27
  145. prefect/utilities/pydantic.py +128 -38
  146. prefect/utilities/schema_tools/hydration.py +18 -1
  147. prefect/utilities/schema_tools/validation.py +30 -0
  148. prefect/utilities/services.py +35 -9
  149. prefect/utilities/templating.py +12 -2
  150. prefect/utilities/timeout.py +20 -5
  151. prefect/utilities/urls.py +195 -0
  152. prefect/utilities/visualization.py +1 -0
  153. prefect/variables.py +78 -59
  154. prefect/workers/__init__.py +0 -1
  155. prefect/workers/base.py +237 -244
  156. prefect/workers/block.py +5 -226
  157. prefect/workers/cloud.py +6 -0
  158. prefect/workers/process.py +265 -12
  159. prefect/workers/server.py +29 -11
  160. {prefect_client-2.20.2.dist-info → prefect_client-3.0.0.dist-info}/METADATA +30 -26
  161. prefect_client-3.0.0.dist-info/RECORD +201 -0
  162. {prefect_client-2.20.2.dist-info → prefect_client-3.0.0.dist-info}/WHEEL +1 -1
  163. prefect/_internal/pydantic/_base_model.py +0 -51
  164. prefect/_internal/pydantic/_compat.py +0 -82
  165. prefect/_internal/pydantic/_flags.py +0 -20
  166. prefect/_internal/pydantic/_types.py +0 -8
  167. prefect/_internal/pydantic/utilities/config_dict.py +0 -72
  168. prefect/_internal/pydantic/utilities/field_validator.py +0 -150
  169. prefect/_internal/pydantic/utilities/model_construct.py +0 -56
  170. prefect/_internal/pydantic/utilities/model_copy.py +0 -55
  171. prefect/_internal/pydantic/utilities/model_dump.py +0 -136
  172. prefect/_internal/pydantic/utilities/model_dump_json.py +0 -112
  173. prefect/_internal/pydantic/utilities/model_fields.py +0 -50
  174. prefect/_internal/pydantic/utilities/model_fields_set.py +0 -29
  175. prefect/_internal/pydantic/utilities/model_json_schema.py +0 -82
  176. prefect/_internal/pydantic/utilities/model_rebuild.py +0 -80
  177. prefect/_internal/pydantic/utilities/model_validate.py +0 -75
  178. prefect/_internal/pydantic/utilities/model_validate_json.py +0 -68
  179. prefect/_internal/pydantic/utilities/model_validator.py +0 -87
  180. prefect/_internal/pydantic/utilities/type_adapter.py +0 -71
  181. prefect/_vendor/fastapi/__init__.py +0 -25
  182. prefect/_vendor/fastapi/applications.py +0 -946
  183. prefect/_vendor/fastapi/background.py +0 -3
  184. prefect/_vendor/fastapi/concurrency.py +0 -44
  185. prefect/_vendor/fastapi/datastructures.py +0 -58
  186. prefect/_vendor/fastapi/dependencies/__init__.py +0 -0
  187. prefect/_vendor/fastapi/dependencies/models.py +0 -64
  188. prefect/_vendor/fastapi/dependencies/utils.py +0 -877
  189. prefect/_vendor/fastapi/encoders.py +0 -177
  190. prefect/_vendor/fastapi/exception_handlers.py +0 -40
  191. prefect/_vendor/fastapi/exceptions.py +0 -46
  192. prefect/_vendor/fastapi/logger.py +0 -3
  193. prefect/_vendor/fastapi/middleware/__init__.py +0 -1
  194. prefect/_vendor/fastapi/middleware/asyncexitstack.py +0 -25
  195. prefect/_vendor/fastapi/middleware/cors.py +0 -3
  196. prefect/_vendor/fastapi/middleware/gzip.py +0 -3
  197. prefect/_vendor/fastapi/middleware/httpsredirect.py +0 -3
  198. prefect/_vendor/fastapi/middleware/trustedhost.py +0 -3
  199. prefect/_vendor/fastapi/middleware/wsgi.py +0 -3
  200. prefect/_vendor/fastapi/openapi/__init__.py +0 -0
  201. prefect/_vendor/fastapi/openapi/constants.py +0 -2
  202. prefect/_vendor/fastapi/openapi/docs.py +0 -203
  203. prefect/_vendor/fastapi/openapi/models.py +0 -480
  204. prefect/_vendor/fastapi/openapi/utils.py +0 -485
  205. prefect/_vendor/fastapi/param_functions.py +0 -340
  206. prefect/_vendor/fastapi/params.py +0 -453
  207. prefect/_vendor/fastapi/py.typed +0 -0
  208. prefect/_vendor/fastapi/requests.py +0 -4
  209. prefect/_vendor/fastapi/responses.py +0 -40
  210. prefect/_vendor/fastapi/routing.py +0 -1331
  211. prefect/_vendor/fastapi/security/__init__.py +0 -15
  212. prefect/_vendor/fastapi/security/api_key.py +0 -98
  213. prefect/_vendor/fastapi/security/base.py +0 -6
  214. prefect/_vendor/fastapi/security/http.py +0 -172
  215. prefect/_vendor/fastapi/security/oauth2.py +0 -227
  216. prefect/_vendor/fastapi/security/open_id_connect_url.py +0 -34
  217. prefect/_vendor/fastapi/security/utils.py +0 -10
  218. prefect/_vendor/fastapi/staticfiles.py +0 -1
  219. prefect/_vendor/fastapi/templating.py +0 -3
  220. prefect/_vendor/fastapi/testclient.py +0 -1
  221. prefect/_vendor/fastapi/types.py +0 -3
  222. prefect/_vendor/fastapi/utils.py +0 -235
  223. prefect/_vendor/fastapi/websockets.py +0 -7
  224. prefect/_vendor/starlette/__init__.py +0 -1
  225. prefect/_vendor/starlette/_compat.py +0 -28
  226. prefect/_vendor/starlette/_exception_handler.py +0 -80
  227. prefect/_vendor/starlette/_utils.py +0 -88
  228. prefect/_vendor/starlette/applications.py +0 -261
  229. prefect/_vendor/starlette/authentication.py +0 -159
  230. prefect/_vendor/starlette/background.py +0 -43
  231. prefect/_vendor/starlette/concurrency.py +0 -59
  232. prefect/_vendor/starlette/config.py +0 -151
  233. prefect/_vendor/starlette/convertors.py +0 -87
  234. prefect/_vendor/starlette/datastructures.py +0 -707
  235. prefect/_vendor/starlette/endpoints.py +0 -130
  236. prefect/_vendor/starlette/exceptions.py +0 -60
  237. prefect/_vendor/starlette/formparsers.py +0 -276
  238. prefect/_vendor/starlette/middleware/__init__.py +0 -17
  239. prefect/_vendor/starlette/middleware/authentication.py +0 -52
  240. prefect/_vendor/starlette/middleware/base.py +0 -220
  241. prefect/_vendor/starlette/middleware/cors.py +0 -176
  242. prefect/_vendor/starlette/middleware/errors.py +0 -265
  243. prefect/_vendor/starlette/middleware/exceptions.py +0 -74
  244. prefect/_vendor/starlette/middleware/gzip.py +0 -113
  245. prefect/_vendor/starlette/middleware/httpsredirect.py +0 -19
  246. prefect/_vendor/starlette/middleware/sessions.py +0 -82
  247. prefect/_vendor/starlette/middleware/trustedhost.py +0 -64
  248. prefect/_vendor/starlette/middleware/wsgi.py +0 -147
  249. prefect/_vendor/starlette/py.typed +0 -0
  250. prefect/_vendor/starlette/requests.py +0 -328
  251. prefect/_vendor/starlette/responses.py +0 -347
  252. prefect/_vendor/starlette/routing.py +0 -933
  253. prefect/_vendor/starlette/schemas.py +0 -154
  254. prefect/_vendor/starlette/staticfiles.py +0 -248
  255. prefect/_vendor/starlette/status.py +0 -199
  256. prefect/_vendor/starlette/templating.py +0 -231
  257. prefect/_vendor/starlette/testclient.py +0 -804
  258. prefect/_vendor/starlette/types.py +0 -30
  259. prefect/_vendor/starlette/websockets.py +0 -193
  260. prefect/blocks/kubernetes.py +0 -119
  261. prefect/deprecated/__init__.py +0 -0
  262. prefect/deprecated/data_documents.py +0 -350
  263. prefect/deprecated/packaging/__init__.py +0 -12
  264. prefect/deprecated/packaging/base.py +0 -96
  265. prefect/deprecated/packaging/docker.py +0 -146
  266. prefect/deprecated/packaging/file.py +0 -92
  267. prefect/deprecated/packaging/orion.py +0 -80
  268. prefect/deprecated/packaging/serializers.py +0 -171
  269. prefect/events/instrument.py +0 -135
  270. prefect/infrastructure/container.py +0 -824
  271. prefect/infrastructure/kubernetes.py +0 -920
  272. prefect/infrastructure/process.py +0 -289
  273. prefect/manifests.py +0 -20
  274. prefect/new_flow_engine.py +0 -449
  275. prefect/new_task_engine.py +0 -423
  276. prefect/pydantic/__init__.py +0 -76
  277. prefect/pydantic/main.py +0 -39
  278. prefect/software/__init__.py +0 -2
  279. prefect/software/base.py +0 -50
  280. prefect/software/conda.py +0 -199
  281. prefect/software/pip.py +0 -122
  282. prefect/software/python.py +0 -52
  283. prefect/task_server.py +0 -322
  284. prefect_client-2.20.2.dist-info/RECORD +0 -294
  285. /prefect/{_internal/pydantic/utilities → client/types}/__init__.py +0 -0
  286. /prefect/{_vendor → concurrency/v1}/__init__.py +0 -0
  287. {prefect_client-2.20.2.dist-info → prefect_client-3.0.0.dist-info}/LICENSE +0 -0
  288. {prefect_client-2.20.2.dist-info → prefect_client-3.0.0.dist-info}/top_level.txt +0 -0
prefect/task_runs.py ADDED
@@ -0,0 +1,240 @@
1
+ import asyncio
2
+ import atexit
3
+ import threading
4
+ import uuid
5
+ from typing import Callable, Dict, Optional
6
+
7
+ import anyio
8
+ from cachetools import TTLCache
9
+ from typing_extensions import Self
10
+
11
+ from prefect._internal.concurrency.api import create_call, from_async, from_sync
12
+ from prefect._internal.concurrency.threads import get_global_loop
13
+ from prefect.client.schemas.objects import TERMINAL_STATES
14
+ from prefect.events.clients import get_events_subscriber
15
+ from prefect.events.filters import EventFilter, EventNameFilter
16
+ from prefect.logging.loggers import get_logger
17
+
18
+
19
class TaskRunWaiter:
    """
    A service used for waiting for a task run to finish.

    This service listens for task run events and provides a way to wait for a
    specific task run to finish. This is useful for waiting for a task run to
    finish before continuing execution.

    The service is a singleton. It is started automatically the first time the
    singleton is requested via `TaskRunWaiter.instance()`, which both
    `wait_for_task_run` and `add_done_callback` do. A single websocket
    connection is used to listen for task run events.

    The service can be used to wait for a task run to finish by calling
    `TaskRunWaiter.wait_for_task_run` with the task run ID to wait for. The
    method will return when the task run has finished or the timeout has
    elapsed.

    The service will automatically stop when the Python process exits or when
    the global loop thread is stopped.

    Example:
    ```python
    import asyncio
    from uuid import uuid4

    from prefect import task
    from prefect.task_engine import run_task_async
    from prefect.task_runs import TaskRunWaiter


    @task
    async def test_task():
        await asyncio.sleep(5)
        print("Done!")


    async def main():
        task_run_id = uuid4()
        asyncio.create_task(run_task_async(task=test_task, task_run_id=task_run_id))

        await TaskRunWaiter.wait_for_task_run(task_run_id)
        print("Task run finished")


    if __name__ == "__main__":
        asyncio.run(main())
    ```
    """

    _instance: Optional["TaskRunWaiter"] = None
    _instance_lock = threading.Lock()

    def __init__(self):
        self.logger = get_logger("TaskRunWaiter")
        self._consumer_task: Optional[asyncio.Task] = None
        # Recently finished task run IDs (bounded size, 600s TTL) so that a
        # waiter arriving shortly after the terminal event returns immediately.
        self._observed_completed_task_runs: TTLCache[uuid.UUID, bool] = TTLCache(
            maxsize=10000, ttl=600
        )
        self._completion_events: Dict[uuid.UUID, asyncio.Event] = {}
        self._completion_callbacks: Dict[uuid.UUID, Callable] = {}
        self._loop: Optional[asyncio.AbstractEventLoop] = None
        self._observed_completed_task_runs_lock = threading.Lock()
        # Guards both `_completion_events` and `_completion_callbacks`.
        self._completion_events_lock = threading.Lock()
        self._started = False

    def start(self):
        """
        Start the TaskRunWaiter service.

        Must be called on the global loop thread. Schedules the event consumer
        on that loop and registers `stop` both as a loop shutdown call and as
        an `atexit` handler. Calling it again after a start is a no-op.

        Raises:
            RuntimeError: If not running on the global loop thread.
        """
        if self._started:
            return
        self.logger.debug("Starting TaskRunWaiter")
        loop_thread = get_global_loop()

        if not asyncio.get_running_loop() == loop_thread._loop:
            raise RuntimeError("TaskRunWaiter must run on the global loop thread.")

        self._loop = loop_thread._loop

        # We are executing on `self._loop` itself, so we cannot block here
        # until the consumer has connected (that would deadlock the loop).
        # The consumer sets `consumer_started` once its subscription is open.
        consumer_started = asyncio.Event()
        self._consumer_task = self._loop.create_task(
            self._consume_events(consumer_started)
        )

        loop_thread.add_shutdown_call(create_call(self.stop))
        atexit.register(self.stop)
        self._started = True

    async def _consume_events(self, consumer_started: asyncio.Event):
        """
        Subscribe to terminal task-run events and wake waiters / callbacks.

        Every received task run ID is recorded as observed. The matching
        completion event (if any) is set, and the matching callback (if any)
        is removed from the registry and invoked exactly once. Errors while
        processing a single event are logged and do not stop the consumer.
        """
        async with get_events_subscriber(
            filter=EventFilter(
                event=EventNameFilter(
                    name=[
                        f"prefect.task-run.{state.name.title()}"
                        for state in TERMINAL_STATES
                    ],
                )
            )
        ) as subscriber:
            consumer_started.set()
            async for event in subscriber:
                try:
                    self.logger.debug(
                        f"Received event: {event.resource['prefect.resource.id']}"
                    )
                    task_run_id = uuid.UUID(
                        event.resource["prefect.resource.id"].replace(
                            "prefect.task-run.", ""
                        )
                    )

                    with self._observed_completed_task_runs_lock:
                        # Cache the task run ID for a short period of time to
                        # avoid unnecessary waits
                        self._observed_completed_task_runs[task_run_id] = True
                    with self._completion_events_lock:
                        # Wake the coroutine waiting on this task run, if any
                        finished_event = self._completion_events.get(task_run_id)
                        if finished_event is not None:
                            finished_event.set()
                        # Pop so each callback runs once and the registry does
                        # not grow without bound.
                        callback = self._completion_callbacks.pop(task_run_id, None)
                    # Invoke outside the lock: `threading.Lock` is not
                    # reentrant and the callback may call back into this class.
                    if callback is not None:
                        callback()
                except Exception as exc:
                    self.logger.error(f"Error processing event: {exc}")

    def stop(self):
        """
        Stop the TaskRunWaiter service.

        Cancels the consumer task and clears the singleton, so the next call
        to `instance()` creates and starts a fresh service. This may run on
        the loop thread (shutdown call) or on another thread (`atexit`).
        """
        self.logger.debug("Stopping TaskRunWaiter")
        consumer_task = self._consumer_task
        if consumer_task:
            self._cancel_consumer_task(consumer_task)
            self._consumer_task = None
        self.__class__._instance = None
        self._started = False

    def _cancel_consumer_task(self, consumer_task: asyncio.Task) -> None:
        """Cancel `consumer_task` safely from whichever thread calls `stop`."""
        loop = self._loop
        try:
            running_loop = asyncio.get_running_loop()
        except RuntimeError:
            running_loop = None

        if loop is None or loop is running_loop:
            # Already on the owning loop: cancelling directly is safe.
            consumer_task.cancel()
        elif not loop.is_closed():
            # `Task.cancel` is not thread-safe; hand it to the owning loop.
            try:
                loop.call_soon_threadsafe(consumer_task.cancel)
            except RuntimeError:
                # The loop closed between the check above and scheduling.
                pass

    @classmethod
    async def wait_for_task_run(
        cls, task_run_id: uuid.UUID, timeout: Optional[float] = None
    ):
        """
        Wait for a task run to finish.

        Note this relies on a websocket connection to receive events from the
        server and will not work with an ephemeral server.

        Args:
            task_run_id: The ID of the task run to wait for.
            timeout: The maximum time to wait for the task run to
                finish. Defaults to None, which waits indefinitely.

        Returns:
            None. When `timeout` elapses first this returns silently (no
            exception), so callers cannot distinguish timeout from completion.
        """
        instance = cls.instance()
        with instance._observed_completed_task_runs_lock:
            if task_run_id in instance._observed_completed_task_runs:
                return

        # Need to create event in loop thread to ensure it can be set
        # from the loop thread
        finished_event = await from_async.wait_for_call_in_loop_thread(
            create_call(asyncio.Event)
        )
        with instance._completion_events_lock:
            # Cache the event for the task run ID so the consumer can set it
            # when the event is received
            instance._completion_events[task_run_id] = finished_event

        try:
            # Now check one more time whether the task run arrived before we
            # start to wait on it, in case it came in while we were setting up
            # the event above.
            with instance._observed_completed_task_runs_lock:
                if task_run_id in instance._observed_completed_task_runs:
                    return

            with anyio.move_on_after(delay=timeout):
                await from_async.wait_for_call_in_loop_thread(
                    create_call(finished_event.wait)
                )
        finally:
            with instance._completion_events_lock:
                # Remove the event from the cache after it has been waited on
                instance._completion_events.pop(task_run_id, None)

    @classmethod
    def add_done_callback(cls, task_run_id: uuid.UUID, callback):
        """
        Add a callback to be called when a task run finishes.

        The callback takes no arguments and is invoked at most once. If the
        task run has already been observed as finished, it is invoked
        immediately in the calling thread. Otherwise it is invoked by the event
        consumer when the terminal event arrives. Registering a second callback
        for the same task run replaces the first.

        Args:
            task_run_id: The ID of the task run to wait for.
            callback: The callback to call when the task run finishes.
        """
        instance = cls.instance()
        with instance._observed_completed_task_runs_lock:
            already_finished = task_run_id in instance._observed_completed_task_runs
        if already_finished:
            callback()
            return

        with instance._completion_events_lock:
            # Register so the consumer can invoke it when the event arrives
            instance._completion_callbacks[task_run_id] = callback

        # The terminal event may have arrived between the check above and the
        # registration, in which case the consumer has already looked for (and
        # missed) this callback. Re-check; popping under the lock guarantees
        # that exactly one of us or the consumer invokes it.
        with instance._observed_completed_task_runs_lock:
            already_finished = task_run_id in instance._observed_completed_task_runs
        if already_finished:
            with instance._completion_events_lock:
                pending = instance._completion_callbacks.pop(task_run_id, None)
            if pending is not None:
                pending()

    @classmethod
    def instance(cls):
        """
        Get the singleton instance of TaskRunWaiter, creating and starting it
        on first use.
        """
        with cls._instance_lock:
            if cls._instance is None:
                cls._instance = cls._new_instance()
            return cls._instance

    @classmethod
    def _new_instance(cls):
        """
        Create a new instance and start it on the global loop thread.

        When called from another thread, this blocks until `start()` has
        returned on the loop thread.
        """
        instance = cls()

        if threading.get_ident() == get_global_loop().thread.ident:
            instance.start()
        else:
            from_sync.call_soon_in_loop_thread(create_call(instance.start)).result()

        return instance
prefect/task_worker.py ADDED
@@ -0,0 +1,463 @@
1
+ import asyncio
2
+ import inspect
3
+ import os
4
+ import signal
5
+ import socket
6
+ import sys
7
+ from concurrent.futures import ThreadPoolExecutor
8
+ from contextlib import AsyncExitStack
9
+ from contextvars import copy_context
10
+ from typing import Optional
11
+ from uuid import UUID
12
+
13
+ import anyio
14
+ import anyio.abc
15
+ import pendulum
16
+ import uvicorn
17
+ from exceptiongroup import BaseExceptionGroup # novermin
18
+ from fastapi import FastAPI
19
+ from websockets.exceptions import InvalidStatusCode
20
+
21
+ from prefect import Task
22
+ from prefect._internal.concurrency.api import create_call, from_sync
23
+ from prefect.cache_policies import DEFAULT, NONE
24
+ from prefect.client.orchestration import get_client
25
+ from prefect.client.schemas.objects import TaskRun
26
+ from prefect.client.subscriptions import Subscription
27
+ from prefect.logging.loggers import get_logger
28
+ from prefect.results import ResultStore, get_or_create_default_task_scheduling_storage
29
+ from prefect.settings import (
30
+ PREFECT_API_URL,
31
+ PREFECT_TASK_SCHEDULING_DELETE_FAILED_SUBMISSIONS,
32
+ )
33
+ from prefect.states import Pending
34
+ from prefect.task_engine import run_task_async, run_task_sync
35
+ from prefect.utilities.annotations import NotSet
36
+ from prefect.utilities.asyncutils import asyncnullcontext, sync_compatible
37
+ from prefect.utilities.engine import emit_task_run_state_change_event
38
+ from prefect.utilities.processutils import _register_signal
39
+ from prefect.utilities.services import start_client_metrics_server
40
+ from prefect.utilities.urls import url_for
41
+
42
# Module-level logger shared by the task worker's signal handling,
# subscription, and task-run submission code below.
logger = get_logger("task_worker")
43
+
44
+
45
class StopTaskWorker(Exception):
    """Raised by `TaskWorker.stop()` to signal that the task worker is stopping."""
49
+
50
+
51
def should_try_to_read_parameters(task: Task, task_run: TaskRun) -> bool:
    """
    Determine whether a task run should read parameters from the result store.

    Parameters are read only when both hold:
    - the run's state details carry a `task_parameters_id`. A run created
      without parameters has no id, and there would be nothing to read.
    - the task's function accepts any parameters at all.

    Args:
        task: The task that will execute the run.
        task_run: The scheduled task run received from the server.

    Returns:
        True if parameters should be loaded for this run.
    """
    # `hasattr` is always true for a declared model field (it defaults to
    # None), so check the value itself; also tolerate a missing state.
    state_details = getattr(task_run.state, "state_details", None)
    parameters_id = getattr(state_details, "task_parameters_id", None)
    if parameters_id is None:
        return False

    return bool(inspect.signature(task.fn).parameters)
59
+
60
+
61
+ class TaskWorker:
62
+ """This class is responsible for serving tasks that may be executed in the background
63
+ by a task runner via the traditional engine machinery.
64
+
65
+ When `start()` is called, the task worker will open a websocket connection to a
66
+ server-side queue of scheduled task runs. When a scheduled task run is found, the
67
+ scheduled task run is submitted to the engine for execution with a minimal `EngineContext`
68
+ so that the task run can be governed by orchestration rules.
69
+
70
+ Args:
71
+ - tasks: A list of tasks to serve. These tasks will be submitted to the engine
72
+ when a scheduled task run is found.
73
+ - limit: The maximum number of tasks that can be run concurrently. Defaults to 10.
74
+ Pass `None` to remove the limit.
75
+ """
76
+
77
def __init__(
    self,
    *tasks: Task,
    limit: Optional[int] = 10,
):
    """
    Args:
        tasks: The tasks to serve. Arguments that are not `Task` instances
            are ignored.
        limit: Maximum number of task runs executed concurrently. `None`
            (or 0) disables the limit.

    Raises:
        RuntimeError: If called without a running event loop.
    """
    # Fail fast before building anything. `get_running_loop()` is the
    # supported way to detect a running loop; `get_event_loop()` may warn or
    # implicitly create a new loop when none is running.
    try:
        asyncio.get_running_loop()
    except RuntimeError:
        raise RuntimeError(
            "TaskWorker must be initialized within an async context."
        ) from None

    # Served tasks always persist results; tasks with no cache policy are
    # given the DEFAULT policy.
    self.tasks = []
    for t in tasks:
        if isinstance(t, Task):
            if t.cache_policy in [None, NONE, NotSet]:
                self.tasks.append(
                    t.with_options(persist_result=True, cache_policy=DEFAULT)
                )
            else:
                self.tasks.append(t.with_options(persist_result=True))

    self.task_keys = set(t.task_key for t in tasks if isinstance(t, Task))

    self._started_at: Optional[pendulum.DateTime] = None
    self.stopping: bool = False

    self._client = get_client()
    self._exit_stack = AsyncExitStack()

    self._runs_task_group: anyio.abc.TaskGroup = anyio.create_task_group()
    self._executor = ThreadPoolExecutor(max_workers=limit if limit else None)
    # A falsy limit (None or 0) means "unlimited": no limiter at all.
    self._limiter = anyio.CapacityLimiter(limit) if limit else None

    # Per task key: run id -> time the run was picked up.
    self.in_flight_task_runs: dict[str, dict[UUID, pendulum.DateTime]] = {
        task_key: {} for task_key in self.task_keys
    }
    # Per task key: number of runs that have finished (successfully or not).
    self.finished_task_runs: dict[str, int] = {
        task_key: 0 for task_key in self.task_keys
    }
115
+
116
@property
def client_id(self) -> str:
    """Identifier of this worker process, formatted as `<hostname>-<pid>`."""
    hostname = socket.gethostname()
    pid = os.getpid()
    return f"{hostname}-{pid}"
119
+
120
@property
def started_at(self) -> Optional[pendulum.DateTime]:
    """The worker's recorded start time; `None` until set and after `stop()`."""
    recorded = self._started_at
    return recorded
123
+
124
@property
def started(self) -> bool:
    """True while the worker has a recorded start time."""
    recorded = self._started_at
    return not (recorded is None)
127
+
128
@property
def limit(self) -> Optional[int]:
    """Maximum number of concurrent task runs, or `None` when unlimited."""
    limiter = self._limiter
    if not limiter:
        return None
    return int(limiter.total_tokens)
131
+
132
@property
def current_tasks(self) -> Optional[int]:
    """
    Number of task runs currently executing.

    Read from the limiter's borrowed tokens when a limit is set; otherwise
    counted from the in-flight bookkeeping.
    """
    limiter = self._limiter
    if limiter:
        return int(limiter.borrowed_tokens)
    running = 0
    for runs in self.in_flight_task_runs.values():
        running += len(runs)
    return running
139
+
140
@property
def available_tasks(self) -> Optional[int]:
    """Free concurrency slots, or `None` when the worker has no limit."""
    limiter = self._limiter
    if not limiter:
        return None
    return int(limiter.available_tokens)
143
+
144
def handle_sigterm(self, signum, frame):
    """
    Shuts down the task worker when a SIGTERM is received.

    Runs `stop()` on the global loop thread, then exits the process with
    status 0.

    Args:
        signum: The received signal number (unused).
        frame: The interrupted stack frame (unused).
    """
    logger.info("SIGTERM received, initiating graceful shutdown...")
    try:
        from_sync.call_in_loop_thread(create_call(self.stop))
    except StopTaskWorker:
        # `stop()` always finishes by raising StopTaskWorker. Without
        # catching it here, the `sys.exit(0)` below would be unreachable.
        pass

    sys.exit(0)
152
+
153
@sync_compatible
async def start(self) -> None:
    """
    Start the task worker and serve the tasks given to the constructor.

    Installs the SIGTERM handler and starts the client metrics server, then
    subscribes to scheduled runs. The worker is entered as an async context
    manager unless it has already been started. A 403 from the subscription
    endpoint is logged with configuration guidance instead of being raised;
    any other `InvalidStatusCode` propagates.
    """
    _register_signal(signal.SIGTERM, self.handle_sigterm)

    start_client_metrics_server()

    lifecycle = asyncnullcontext() if self.started else self
    async with lifecycle:
        logger.info("Starting task worker...")
        try:
            await self._subscribe_to_task_scheduling()
        except InvalidStatusCode as exc:
            if exc.status_code != 403:
                raise
            logger.error(
                "403: Could not establish a connection to the `/task_runs/subscriptions/scheduled`"
                f" endpoint found at:\n\n {PREFECT_API_URL.value()}"
                "\n\nPlease double-check the values of your"
                " `PREFECT_API_URL` and `PREFECT_API_KEY` environment variables."
            )
176
+
177
@sync_compatible
async def stop(self):
    """
    Stop the task worker's polling cycle.

    Clears the start time, marks the worker as stopping, and then raises
    `StopTaskWorker` to signal the shutdown to the caller.

    Raises:
        RuntimeError: If the worker has not been started.
        StopTaskWorker: Always, after a successful stop.
    """
    if self.started:
        self._started_at = None
        self.stopping = True
        raise StopTaskWorker

    raise RuntimeError(
        "Task worker has not yet started. Please start the task worker by"
        " calling .start()"
    )
190
+
191
+ async def _acquire_token(self, task_run_id: UUID) -> bool:
192
+ try:
193
+ if self._limiter:
194
+ await self._limiter.acquire_on_behalf_of(task_run_id)
195
+ except RuntimeError:
196
+ logger.debug(f"Token already acquired for task run: {task_run_id!r}")
197
+ return False
198
+
199
+ return True
200
+
201
+ def _release_token(self, task_run_id: UUID) -> bool:
202
+ try:
203
+ if self._limiter:
204
+ self._limiter.release_on_behalf_of(task_run_id)
205
+ except RuntimeError:
206
+ logger.debug(f"No token to release for task run: {task_run_id!r}")
207
+ return False
208
+
209
+ return True
210
+
211
async def _subscribe_to_task_scheduling(self):
    """
    Consume scheduled task runs for this worker's task keys.

    For each received run, a concurrency token is acquired (this waits while
    the worker is at its limit). The run is then submitted in the worker's
    task group. Runs whose token acquisition reports failure are skipped.

    Raises:
        ValueError: If `PREFECT_API_URL` is not set (ephemeral API).
    """
    api_url = PREFECT_API_URL.value()
    if api_url is None:
        raise ValueError(
            "`PREFECT_API_URL` must be set to use the task worker. "
            "Task workers are not compatible with the ephemeral API."
        )

    # Short, human-readable names: last dotted segment, before the first "-".
    short_names = [
        key.split(".")[-1].split("-")[0] for key in sorted(self.task_keys)
    ]
    task_keys_repr = " | ".join(short_names)
    logger.info(f"Subscribing to runs of task(s): {task_keys_repr}")

    scheduled_runs = Subscription(
        model=TaskRun,
        path="/task_runs/subscriptions/scheduled",
        keys=self.task_keys,
        client_id=self.client_id,
        base_url=api_url,
    )
    async for task_run in scheduled_runs:
        logger.info(f"Received task run: {task_run.id} - {task_run.name}")

        if not await self._acquire_token(task_run.id):
            continue
        self._runs_task_group.start_soon(
            self._safe_submit_scheduled_task_run, task_run
        )
236
+
237
async def _safe_submit_scheduled_task_run(self, task_run: TaskRun):
    """
    Submit a scheduled task run, tracking it as in flight while it runs.

    Any exception from submission (including BaseExceptions) is logged and
    not re-raised, presumably so one failing run does not tear down the
    worker's task group. The concurrency token acquired for this run is
    always released, and the run is counted as finished.

    Args:
        task_run: The scheduled task run to execute.
    """
    try:
        # `setdefault`: a run for an unexpected task key must not raise
        # KeyError here. Before this change, that lookup sat outside the
        # try/finally, which leaked the limiter token and crashed the group.
        self.in_flight_task_runs.setdefault(task_run.task_key, {})[
            task_run.id
        ] = pendulum.now()
        await self._submit_scheduled_task_run(task_run)
    except BaseException as exc:
        logger.exception(
            f"Failed to submit task run {task_run.id!r}",
            exc_info=exc,
        )
    finally:
        self.in_flight_task_runs.get(task_run.task_key, {}).pop(task_run.id, None)
        self.finished_task_runs[task_run.task_key] = (
            self.finished_task_runs.get(task_run.task_key, 0) + 1
        )
        self._release_token(task_run.id)
250
+
251
async def _submit_scheduled_task_run(self, task_run: TaskRun):
    """Execute a scheduled task run received by this worker.

    Steps:

    1. Find the served task whose ``task_key`` matches the run. If there is
       none, log a warning. When
       ``PREFECT_TASK_SCHEDULING_DELETE_FAILED_SUBMISSIONS`` is enabled, also
       delete the run through the API. Then return.
    2. If ``should_try_to_read_parameters`` allows it, read the stored
       parameters, ``wait_for`` and context from the default task-scheduling
       result storage. On failure, log the error, delete the run if the
       setting above is enabled, and return.
    3. Move the run to a deferred ``Pending`` state locally and emit a
       state-change event for the transition.
    4. Run the task: ``run_task_async`` for async tasks, or ``run_task_sync``
       in ``self._executor`` for sync tasks. Both use ``return_type="state"``.
    """
    logger.debug(
        f"Found task run: {task_run.name!r} in state: {task_run.state.name!r}"
    )

    task = next((t for t in self.tasks if t.task_key == task_run.task_key), None)

    if not task:
        # Always report the unknown task; previously the warning was emitted
        # only on the delete path, so a disabled setting dropped the run
        # silently. Use `.value()` as the parameter-failure path below does.
        logger.warning(f"Task {task_run.name!r} not found in task worker registry.")
        if PREFECT_TASK_SCHEDULING_DELETE_FAILED_SUBMISSIONS.value():
            await self._client._client.delete(f"/task_runs/{task_run.id}")  # type: ignore
        return

    # The ID of the parameters for this run is stored in the Scheduled
    # state's state_details. If there is no parameters_id, the task was
    # created without parameters and the defaults below are used.
    parameters = {}
    wait_for = []
    run_context = None
    if should_try_to_read_parameters(task, task_run):
        parameters_id = task_run.state.state_details.task_parameters_id
        # NOTE(review): this mutates the shared task object; presumably
        # required so `update_for_task` configures persistence — confirm.
        task.persist_result = True
        store = await ResultStore(
            result_storage=await get_or_create_default_task_scheduling_storage()
        ).update_for_task(task)
        try:
            run_data = await store.read_parameters(parameters_id)
            parameters = run_data.get("parameters", {})
            wait_for = run_data.get("wait_for", [])
            run_context = run_data.get("context", None)
        except Exception as exc:
            logger.exception(
                f"Failed to read parameters for task run {task_run.id!r}",
                exc_info=exc,
            )
            if PREFECT_TASK_SCHEDULING_DELETE_FAILED_SUBMISSIONS.value():
                logger.info(
                    f"Deleting task run {task_run.id!r} because it failed to submit"
                )
                await self._client._client.delete(f"/task_runs/{task_run.id}")
            return

    # Transition locally from the received state to a deferred Pending state.
    initial_state = task_run.state
    pending_state = Pending()
    pending_state.state_details.deferred = True
    pending_state.state_details.task_run_id = task_run.id
    pending_state.state_details.flow_run_id = task_run.flow_run_id
    task_run.state = pending_state

    emit_task_run_state_change_event(
        task_run=task_run,
        initial_state=initial_state,
        validated_state=pending_state,
    )

    if task_run_url := url_for(task_run):
        logger.info(
            f"Submitting task run {task_run.name!r} to engine. View in the UI: {task_run_url}"
        )

    if task.isasync:
        await run_task_async(
            task=task,
            task_run_id=task_run.id,
            task_run=task_run,
            parameters=parameters,
            wait_for=wait_for,
            return_type="state",
            context=run_context,
        )
    else:
        # Run sync tasks in the worker's executor, inside a copy of the
        # current contextvars so context visible here is visible in the
        # executor thread as well.
        context = copy_context()
        future = self._executor.submit(
            context.run,
            run_task_sync,
            task=task,
            task_run_id=task_run.id,
            task_run=task_run,
            parameters=parameters,
            wait_for=wait_for,
            return_type="state",
            context=run_context,
        )
        await asyncio.wrap_future(future)
339
+
340
async def execute_task_run(self, task_run: TaskRun):
    """Execute a task run in the task worker.

    If the worker is not already started, it is entered as an async context
    manager for the duration of this call. Otherwise a no-op async context is
    used. The run is submitted only when ``_acquire_token`` succeeds for it.
    """
    lifecycle = asyncnullcontext() if self.started else self
    async with lifecycle:
        if not await self._acquire_token(task_run.id):
            return
        await self._safe_submit_scheduled_task_run(task_run)
346
+
347
async def __aenter__(self):
    """Start the task worker and return it.

    If the current API client is closed, replaces it with a fresh one from
    ``get_client()``. Then enters, in order, the client, the runs task group
    and the executor on ``self._exit_stack``, so that ``__aexit__`` can
    unwind them. Finally records the start time.
    """
    logger.debug("Starting task worker...")

    if self._client._closed:
        self._client = get_client()

    stack = self._exit_stack
    await stack.enter_async_context(self._client)
    await stack.enter_async_context(self._runs_task_group)
    stack.enter_context(self._executor)

    self._started_at = pendulum.now()
    return self
359
+
360
async def __aexit__(self, *exc_details):
    """Stop the task worker.

    Clears the recorded start time, then unwinds everything entered on
    ``self._exit_stack``, forwarding the exception details it was given.
    """
    logger.debug("Stopping task worker...")
    self._started_at = None
    await self._exit_stack.__aexit__(*exc_details)
364
+
365
+
366
def create_status_server(task_worker: TaskWorker) -> FastAPI:
    """Build a FastAPI app exposing ``GET /status`` for ``task_worker``.

    The endpoint reports:

    - the worker's client ID, start time and stopping flag;
    - its limit and current / available task counts;
    - the sorted served task keys;
    - finished-run counts;
    - the start time of each in-flight run, grouped by task key.
    """
    status_app = FastAPI()

    @status_app.get("/status")
    def status():
        # NOTE(review): `started_at` can presumably be None while the worker is
        # not running (`__aexit__` resets `_started_at` to None, and `serve`
        # starts this server before `start()`), so avoid calling isoformat()
        # on None and failing the request.
        started_at = task_worker.started_at
        return {
            "client_id": task_worker.client_id,
            "started_at": started_at.isoformat() if started_at else None,
            "stopping": task_worker.stopping,
            "limit": task_worker.limit,
            "current": task_worker.current_tasks,
            "available": task_worker.available_tasks,
            "tasks": sorted(task_worker.task_keys),
            "finished": task_worker.finished_task_runs,
            "in_flight": {
                key: {str(run): start.isoformat() for run, start in tasks.items()}
                for key, tasks in task_worker.in_flight_task_runs.items()
            },
        }

    return status_app
387
+
388
+
389
@sync_compatible
async def serve(
    *tasks: Task, limit: Optional[int] = 10, status_server_port: Optional[int] = None
):
    """Serve the provided tasks so that their runs may be submitted to and
    executed in the engine. Tasks do not need to be within a flow run context
    to be submitted. You must `.submit` the same task object that you pass to
    `serve`.

    Args:
        - tasks: A list of tasks to serve. When a scheduled task run is found for a
            given task, the task run will be submitted to the engine for execution.
        - limit: The maximum number of tasks that can be run concurrently. Defaults to 10.
            Pass `None` to remove the limit.
        - status_server_port: An optional port on which to start an HTTP server
            exposing status information about the task worker. If not provided, no
            status server will run.

    Example:
        ```python
        from prefect import task
        from prefect.task_worker import serve

        @task(log_prints=True)
        def say(message: str):
            print(message)

        @task(log_prints=True)
        def yell(message: str):
            print(message.upper())

        # starts a long-lived process that listens for scheduled runs of these tasks
        if __name__ == "__main__":
            serve(say, yell)
        ```
    """
    task_worker = TaskWorker(*tasks, limit=limit)

    status_server_task = None
    if status_server_port is not None:
        server = uvicorn.Server(
            uvicorn.Config(
                app=create_status_server(task_worker),
                host="127.0.0.1",
                port=status_server_port,
                access_log=False,
                log_level="warning",
            )
        )
        # Run the status server concurrently on the currently running loop.
        status_server_task = asyncio.create_task(server.serve())

    try:
        await task_worker.start()

    except BaseExceptionGroup as exc:  # novermin
        exceptions = exc.exceptions
        n_exceptions = len(exceptions)
        logger.error(
            f"Task worker stopped with {n_exceptions} exception{'s' if n_exceptions != 1 else ''}:"
            f"\n" + "\n".join(str(e) for e in exceptions)
        )

    except StopTaskWorker:
        logger.info("Task worker stopped.")

    except (asyncio.CancelledError, KeyboardInterrupt):
        logger.info("Task worker interrupted, stopping...")

    finally:
        # Always shut down the status server, waiting for its cancellation
        # to complete.
        if status_server_task:
            status_server_task.cancel()
            try:
                await status_server_task
            except asyncio.CancelledError:
                pass