prefect-client 3.0.0rc3__py3-none-any.whl → 3.0.0rc4__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- prefect/__init__.py +0 -1
- prefect/client/subscriptions.py +3 -3
- prefect/flow_engine.py +82 -2
- prefect/flows.py +12 -2
- prefect/futures.py +9 -1
- prefect/results.py +33 -31
- prefect/settings.py +1 -1
- prefect/task_engine.py +5 -6
- prefect/task_runs.py +23 -9
- prefect/task_worker.py +128 -19
- prefect/tasks.py +20 -14
- prefect/transactions.py +6 -8
- prefect/types/__init__.py +10 -3
- prefect/utilities/collections.py +120 -57
- prefect/utilities/urls.py +5 -5
- {prefect_client-3.0.0rc3.dist-info → prefect_client-3.0.0rc4.dist-info}/METADATA +2 -2
- {prefect_client-3.0.0rc3.dist-info → prefect_client-3.0.0rc4.dist-info}/RECORD +20 -21
- prefect/blocks/kubernetes.py +0 -115
- {prefect_client-3.0.0rc3.dist-info → prefect_client-3.0.0rc4.dist-info}/LICENSE +0 -0
- {prefect_client-3.0.0rc3.dist-info → prefect_client-3.0.0rc4.dist-info}/WHEEL +0 -0
- {prefect_client-3.0.0rc3.dist-info → prefect_client-3.0.0rc4.dist-info}/top_level.txt +0 -0
prefect/__init__.py
CHANGED
prefect/client/subscriptions.py
CHANGED
@@ -1,5 +1,5 @@
|
|
1
1
|
import asyncio
|
2
|
-
from typing import Any, Dict, Generic,
|
2
|
+
from typing import Any, Dict, Generic, Iterable, Optional, Type, TypeVar
|
3
3
|
|
4
4
|
import orjson
|
5
5
|
import websockets
|
@@ -21,7 +21,7 @@ class Subscription(Generic[S]):
|
|
21
21
|
self,
|
22
22
|
model: Type[S],
|
23
23
|
path: str,
|
24
|
-
keys:
|
24
|
+
keys: Iterable[str],
|
25
25
|
client_id: Optional[str] = None,
|
26
26
|
base_url: Optional[str] = None,
|
27
27
|
):
|
@@ -30,7 +30,7 @@ class Subscription(Generic[S]):
|
|
30
30
|
base_url = base_url.replace("http", "ws", 1)
|
31
31
|
self.subscription_url = f"{base_url}{path}"
|
32
32
|
|
33
|
-
self.keys = keys
|
33
|
+
self.keys = list(keys)
|
34
34
|
|
35
35
|
self._connect = websockets.connect(
|
36
36
|
self.subscription_url,
|
prefect/flow_engine.py
CHANGED
@@ -6,6 +6,7 @@ from contextlib import ExitStack, contextmanager
|
|
6
6
|
from dataclasses import dataclass, field
|
7
7
|
from typing import (
|
8
8
|
Any,
|
9
|
+
AsyncGenerator,
|
9
10
|
Callable,
|
10
11
|
Coroutine,
|
11
12
|
Dict,
|
@@ -50,12 +51,13 @@ from prefect.states import (
|
|
50
51
|
return_value_to_state,
|
51
52
|
)
|
52
53
|
from prefect.utilities.asyncutils import run_coro_as_sync
|
53
|
-
from prefect.utilities.callables import call_with_parameters
|
54
|
+
from prefect.utilities.callables import call_with_parameters, parameters_to_args_kwargs
|
54
55
|
from prefect.utilities.collections import visit_collection
|
55
56
|
from prefect.utilities.engine import (
|
56
57
|
_get_hook_name,
|
57
58
|
_resolve_custom_flow_run_name,
|
58
59
|
capture_sigterm,
|
60
|
+
link_state_to_result,
|
59
61
|
propose_state_sync,
|
60
62
|
resolve_to_final_result,
|
61
63
|
)
|
@@ -632,6 +634,80 @@ async def run_flow_async(
|
|
632
634
|
return engine.state if return_type == "state" else engine.result()
|
633
635
|
|
634
636
|
|
637
|
+
def run_generator_flow_sync(
|
638
|
+
flow: Flow[P, R],
|
639
|
+
flow_run: Optional[FlowRun] = None,
|
640
|
+
parameters: Optional[Dict[str, Any]] = None,
|
641
|
+
wait_for: Optional[Iterable[PrefectFuture]] = None,
|
642
|
+
return_type: Literal["state", "result"] = "result",
|
643
|
+
) -> Generator[R, None, None]:
|
644
|
+
if return_type != "result":
|
645
|
+
raise ValueError("The return_type for a generator flow must be 'result'")
|
646
|
+
|
647
|
+
engine = FlowRunEngine[P, R](
|
648
|
+
flow=flow, parameters=parameters, flow_run=flow_run, wait_for=wait_for
|
649
|
+
)
|
650
|
+
|
651
|
+
with engine.start():
|
652
|
+
while engine.is_running():
|
653
|
+
with engine.run_context():
|
654
|
+
call_args, call_kwargs = parameters_to_args_kwargs(
|
655
|
+
flow.fn, engine.parameters or {}
|
656
|
+
)
|
657
|
+
gen = flow.fn(*call_args, **call_kwargs)
|
658
|
+
try:
|
659
|
+
while True:
|
660
|
+
gen_result = next(gen)
|
661
|
+
# link the current state to the result for dependency tracking
|
662
|
+
link_state_to_result(engine.state, gen_result)
|
663
|
+
yield gen_result
|
664
|
+
except StopIteration as exc:
|
665
|
+
engine.handle_success(exc.value)
|
666
|
+
except GeneratorExit as exc:
|
667
|
+
engine.handle_success(None)
|
668
|
+
gen.throw(exc)
|
669
|
+
|
670
|
+
return engine.result()
|
671
|
+
|
672
|
+
|
673
|
+
async def run_generator_flow_async(
|
674
|
+
flow: Flow[P, R],
|
675
|
+
flow_run: Optional[FlowRun] = None,
|
676
|
+
parameters: Optional[Dict[str, Any]] = None,
|
677
|
+
wait_for: Optional[Iterable[PrefectFuture]] = None,
|
678
|
+
return_type: Literal["state", "result"] = "result",
|
679
|
+
) -> AsyncGenerator[R, None]:
|
680
|
+
if return_type != "result":
|
681
|
+
raise ValueError("The return_type for a generator flow must be 'result'")
|
682
|
+
|
683
|
+
engine = FlowRunEngine[P, R](
|
684
|
+
flow=flow, parameters=parameters, flow_run=flow_run, wait_for=wait_for
|
685
|
+
)
|
686
|
+
|
687
|
+
with engine.start():
|
688
|
+
while engine.is_running():
|
689
|
+
with engine.run_context():
|
690
|
+
call_args, call_kwargs = parameters_to_args_kwargs(
|
691
|
+
flow.fn, engine.parameters or {}
|
692
|
+
)
|
693
|
+
gen = flow.fn(*call_args, **call_kwargs)
|
694
|
+
try:
|
695
|
+
while True:
|
696
|
+
# can't use anext in Python < 3.10
|
697
|
+
gen_result = await gen.__anext__()
|
698
|
+
# link the current state to the result for dependency tracking
|
699
|
+
link_state_to_result(engine.state, gen_result)
|
700
|
+
yield gen_result
|
701
|
+
except (StopAsyncIteration, GeneratorExit) as exc:
|
702
|
+
engine.handle_success(None)
|
703
|
+
if isinstance(exc, GeneratorExit):
|
704
|
+
gen.throw(exc)
|
705
|
+
|
706
|
+
# async generators can't return, but we can raise failures here
|
707
|
+
if engine.state.is_failed():
|
708
|
+
engine.result()
|
709
|
+
|
710
|
+
|
635
711
|
def run_flow(
|
636
712
|
flow: Flow[P, R],
|
637
713
|
flow_run: Optional[FlowRun] = None,
|
@@ -646,7 +722,11 @@ def run_flow(
|
|
646
722
|
wait_for=wait_for,
|
647
723
|
return_type=return_type,
|
648
724
|
)
|
649
|
-
if flow.isasync:
|
725
|
+
if flow.isasync and flow.isgenerator:
|
726
|
+
return run_generator_flow_async(**kwargs)
|
727
|
+
elif flow.isgenerator:
|
728
|
+
return run_generator_flow_sync(**kwargs)
|
729
|
+
elif flow.isasync:
|
650
730
|
return run_flow_async(**kwargs)
|
651
731
|
else:
|
652
732
|
return run_flow_sync(**kwargs)
|
prefect/flows.py
CHANGED
@@ -89,7 +89,6 @@ from prefect.task_runners import TaskRunner, ThreadPoolTaskRunner
|
|
89
89
|
from prefect.types import BANNED_CHARACTERS, WITHOUT_BANNED_CHARACTERS
|
90
90
|
from prefect.utilities.annotations import NotSet
|
91
91
|
from prefect.utilities.asyncutils import (
|
92
|
-
is_async_fn,
|
93
92
|
run_sync_in_worker_thread,
|
94
93
|
sync_compatible,
|
95
94
|
)
|
@@ -289,7 +288,18 @@ class Flow(Generic[P, R]):
|
|
289
288
|
self.description = description or inspect.getdoc(fn)
|
290
289
|
update_wrapper(self, fn)
|
291
290
|
self.fn = fn
|
292
|
-
|
291
|
+
|
292
|
+
# the flow is considered async if its function is async or an async
|
293
|
+
# generator
|
294
|
+
self.isasync = inspect.iscoroutinefunction(
|
295
|
+
self.fn
|
296
|
+
) or inspect.isasyncgenfunction(self.fn)
|
297
|
+
|
298
|
+
# the flow is considered a generator if its function is a generator or
|
299
|
+
# an async generator
|
300
|
+
self.isgenerator = inspect.isgeneratorfunction(
|
301
|
+
self.fn
|
302
|
+
) or inspect.isasyncgenfunction(self.fn)
|
293
303
|
|
294
304
|
raise_for_reserved_arguments(self.fn, ["return_state", "wait_for"])
|
295
305
|
|
prefect/futures.py
CHANGED
@@ -56,7 +56,7 @@ class PrefectFuture(abc.ABC):
|
|
56
56
|
def wait(self, timeout: Optional[float] = None) -> None:
|
57
57
|
...
|
58
58
|
"""
|
59
|
-
Wait for the task run to complete.
|
59
|
+
Wait for the task run to complete.
|
60
60
|
|
61
61
|
If the task run has already completed, this method will return immediately.
|
62
62
|
|
@@ -163,6 +163,10 @@ class PrefectDistributedFuture(PrefectFuture):
|
|
163
163
|
)
|
164
164
|
return
|
165
165
|
|
166
|
+
# Ask for the instance of TaskRunWaiter _now_ so that it's already running and
|
167
|
+
# can catch the completion event if it happens before we start listening for it.
|
168
|
+
TaskRunWaiter.instance()
|
169
|
+
|
166
170
|
# Read task run to see if it is still running
|
167
171
|
async with get_client() as client:
|
168
172
|
task_run = await client.read_task_run(task_run_id=self._task_run_id)
|
@@ -245,6 +249,10 @@ def resolve_futures_to_states(
|
|
245
249
|
context={},
|
246
250
|
)
|
247
251
|
|
252
|
+
# if no futures were found, return the original expression
|
253
|
+
if not futures:
|
254
|
+
return expr
|
255
|
+
|
248
256
|
# Get final states for each future
|
249
257
|
states = []
|
250
258
|
for future in futures:
|
prefect/results.py
CHANGED
@@ -38,7 +38,6 @@ from prefect.settings import (
|
|
38
38
|
PREFECT_RESULTS_DEFAULT_SERIALIZER,
|
39
39
|
PREFECT_RESULTS_PERSIST_BY_DEFAULT,
|
40
40
|
PREFECT_TASK_SCHEDULING_DEFAULT_STORAGE_BLOCK,
|
41
|
-
default_result_storage_block_name,
|
42
41
|
)
|
43
42
|
from prefect.utilities.annotations import NotSet
|
44
43
|
from prefect.utilities.asyncutils import sync_compatible
|
@@ -62,35 +61,15 @@ logger = get_logger("results")
|
|
62
61
|
P = ParamSpec("P")
|
63
62
|
R = TypeVar("R")
|
64
63
|
|
65
|
-
|
66
|
-
@sync_compatible
|
67
|
-
async def get_default_result_storage() -> ResultStorage:
|
68
|
-
"""
|
69
|
-
Generate a default file system for result storage.
|
70
|
-
"""
|
71
|
-
try:
|
72
|
-
return await Block.load(PREFECT_DEFAULT_RESULT_STORAGE_BLOCK.value())
|
73
|
-
except ValueError as e:
|
74
|
-
if "Unable to find" not in str(e):
|
75
|
-
raise e
|
76
|
-
elif (
|
77
|
-
PREFECT_DEFAULT_RESULT_STORAGE_BLOCK.value()
|
78
|
-
== default_result_storage_block_name()
|
79
|
-
):
|
80
|
-
return LocalFileSystem(basepath=PREFECT_LOCAL_STORAGE_PATH.value())
|
81
|
-
else:
|
82
|
-
raise
|
64
|
+
_default_storages: Dict[Tuple[str, str], WritableFileSystem] = {}
|
83
65
|
|
84
66
|
|
85
|
-
|
86
|
-
|
87
|
-
|
88
|
-
async def get_or_create_default_task_scheduling_storage() -> ResultStorage:
|
67
|
+
async def _get_or_create_default_storage(block_document_slug: str) -> ResultStorage:
|
89
68
|
"""
|
90
|
-
Generate a default file system for
|
69
|
+
Generate a default file system for storage.
|
91
70
|
"""
|
92
71
|
default_storage_name, storage_path = cache_key = (
|
93
|
-
|
72
|
+
block_document_slug,
|
94
73
|
PREFECT_LOCAL_STORAGE_PATH.value(),
|
95
74
|
)
|
96
75
|
|
@@ -105,8 +84,8 @@ async def get_or_create_default_task_scheduling_storage() -> ResultStorage:
|
|
105
84
|
if block_type_slug == "local-file-system":
|
106
85
|
block = LocalFileSystem(basepath=storage_path)
|
107
86
|
else:
|
108
|
-
raise
|
109
|
-
"The default
|
87
|
+
raise ValueError(
|
88
|
+
"The default storage block does not exist, but it is of type "
|
110
89
|
f"'{block_type_slug}' which cannot be created implicitly. Please create "
|
111
90
|
"the block manually."
|
112
91
|
)
|
@@ -123,13 +102,32 @@ async def get_or_create_default_task_scheduling_storage() -> ResultStorage:
|
|
123
102
|
return block
|
124
103
|
|
125
104
|
try:
|
126
|
-
return
|
105
|
+
return _default_storages[cache_key]
|
127
106
|
except KeyError:
|
128
107
|
storage = await get_storage()
|
129
|
-
|
108
|
+
_default_storages[cache_key] = storage
|
130
109
|
return storage
|
131
110
|
|
132
111
|
|
112
|
+
@sync_compatible
|
113
|
+
async def get_or_create_default_result_storage() -> ResultStorage:
|
114
|
+
"""
|
115
|
+
Generate a default file system for result storage.
|
116
|
+
"""
|
117
|
+
return await _get_or_create_default_storage(
|
118
|
+
PREFECT_DEFAULT_RESULT_STORAGE_BLOCK.value()
|
119
|
+
)
|
120
|
+
|
121
|
+
|
122
|
+
async def get_or_create_default_task_scheduling_storage() -> ResultStorage:
|
123
|
+
"""
|
124
|
+
Generate a default file system for background task parameter/result storage.
|
125
|
+
"""
|
126
|
+
return await _get_or_create_default_storage(
|
127
|
+
PREFECT_TASK_SCHEDULING_DEFAULT_STORAGE_BLOCK.value()
|
128
|
+
)
|
129
|
+
|
130
|
+
|
133
131
|
def get_default_result_serializer() -> ResultSerializer:
|
134
132
|
"""
|
135
133
|
Generate a default file system for result storage.
|
@@ -210,7 +208,9 @@ class ResultFactory(BaseModel):
|
|
210
208
|
kwargs.pop(key)
|
211
209
|
|
212
210
|
# Apply defaults
|
213
|
-
kwargs.setdefault(
|
211
|
+
kwargs.setdefault(
|
212
|
+
"result_storage", await get_or_create_default_result_storage()
|
213
|
+
)
|
214
214
|
kwargs.setdefault("result_serializer", get_default_result_serializer())
|
215
215
|
kwargs.setdefault("persist_result", get_default_persist_setting())
|
216
216
|
kwargs.setdefault("cache_result_in_memory", True)
|
@@ -280,7 +280,9 @@ class ResultFactory(BaseModel):
|
|
280
280
|
"""
|
281
281
|
Create a new result factory for a task.
|
282
282
|
"""
|
283
|
-
return await cls._from_task(
|
283
|
+
return await cls._from_task(
|
284
|
+
task, get_or_create_default_result_storage, client=client
|
285
|
+
)
|
284
286
|
|
285
287
|
@classmethod
|
286
288
|
@inject_client
|
prefect/settings.py
CHANGED
@@ -1541,7 +1541,7 @@ The maximum number of retries to queue for submission.
|
|
1541
1541
|
|
1542
1542
|
PREFECT_TASK_SCHEDULING_PENDING_TASK_TIMEOUT = Setting(
|
1543
1543
|
timedelta,
|
1544
|
-
default=timedelta(
|
1544
|
+
default=timedelta(0),
|
1545
1545
|
)
|
1546
1546
|
"""
|
1547
1547
|
How long before a PENDING task are made available to another task worker. In practice,
|
prefect/task_engine.py
CHANGED
@@ -417,9 +417,7 @@ class TaskRunEngine(Generic[P, R]):
|
|
417
417
|
log_prints=log_prints,
|
418
418
|
task_run=self.task_run,
|
419
419
|
parameters=self.parameters,
|
420
|
-
result_factory=run_coro_as_sync(
|
421
|
-
ResultFactory.from_autonomous_task(self.task)
|
422
|
-
), # type: ignore
|
420
|
+
result_factory=run_coro_as_sync(ResultFactory.from_task(self.task)), # type: ignore
|
423
421
|
client=client,
|
424
422
|
)
|
425
423
|
)
|
@@ -467,9 +465,6 @@ class TaskRunEngine(Generic[P, R]):
|
|
467
465
|
extra_task_inputs=dependencies,
|
468
466
|
)
|
469
467
|
)
|
470
|
-
self.logger.info(
|
471
|
-
f"Created task run {self.task_run.name!r} for task {self.task.name!r}"
|
472
|
-
)
|
473
468
|
# Emit an event to capture that the task run was in the `PENDING` state.
|
474
469
|
self._last_event = emit_task_run_state_change_event(
|
475
470
|
task_run=self.task_run,
|
@@ -478,6 +473,10 @@ class TaskRunEngine(Generic[P, R]):
|
|
478
473
|
)
|
479
474
|
|
480
475
|
with self.setup_run_context():
|
476
|
+
# setup_run_context might update the task run name, so log creation here
|
477
|
+
self.logger.info(
|
478
|
+
f"Created task run {self.task_run.name!r} for task {self.task.name!r}"
|
479
|
+
)
|
481
480
|
yield self
|
482
481
|
|
483
482
|
except Exception:
|
prefect/task_runs.py
CHANGED
@@ -92,13 +92,18 @@ class TaskRunWaiter:
|
|
92
92
|
raise RuntimeError("TaskRunWaiter must run on the global loop thread.")
|
93
93
|
|
94
94
|
self._loop = loop_thread._loop
|
95
|
-
|
95
|
+
|
96
|
+
consumer_started = asyncio.Event()
|
97
|
+
self._consumer_task = self._loop.create_task(
|
98
|
+
self._consume_events(consumer_started)
|
99
|
+
)
|
100
|
+
asyncio.run_coroutine_threadsafe(consumer_started.wait(), self._loop)
|
96
101
|
|
97
102
|
loop_thread.add_shutdown_call(create_call(self.stop))
|
98
103
|
atexit.register(self.stop)
|
99
104
|
self._started = True
|
100
105
|
|
101
|
-
async def _consume_events(self):
|
106
|
+
async def _consume_events(self, consumer_started: asyncio.Event):
|
102
107
|
async with get_events_subscriber(
|
103
108
|
filter=EventFilter(
|
104
109
|
event=EventNameFilter(
|
@@ -109,6 +114,7 @@ class TaskRunWaiter:
|
|
109
114
|
)
|
110
115
|
)
|
111
116
|
) as subscriber:
|
117
|
+
consumer_started.set()
|
112
118
|
async for event in subscriber:
|
113
119
|
try:
|
114
120
|
self.logger.debug(
|
@@ -119,6 +125,7 @@ class TaskRunWaiter:
|
|
119
125
|
"prefect.task-run.", ""
|
120
126
|
)
|
121
127
|
)
|
128
|
+
|
122
129
|
with self._observed_completed_task_runs_lock:
|
123
130
|
# Cache the task run ID for a short period of time to avoid
|
124
131
|
# unnecessary waits
|
@@ -172,14 +179,21 @@ class TaskRunWaiter:
|
|
172
179
|
# when the event is received
|
173
180
|
instance._completion_events[task_run_id] = finished_event
|
174
181
|
|
175
|
-
|
176
|
-
|
177
|
-
|
178
|
-
|
182
|
+
try:
|
183
|
+
# Now check one more time whether the task run arrived before we start to
|
184
|
+
# wait on it, in case it came in while we were setting up the event above.
|
185
|
+
with instance._observed_completed_task_runs_lock:
|
186
|
+
if task_run_id in instance._observed_completed_task_runs:
|
187
|
+
return
|
179
188
|
|
180
|
-
|
181
|
-
|
182
|
-
|
189
|
+
with anyio.move_on_after(delay=timeout):
|
190
|
+
await from_async.wait_for_call_in_loop_thread(
|
191
|
+
create_call(finished_event.wait)
|
192
|
+
)
|
193
|
+
finally:
|
194
|
+
with instance._completion_events_lock:
|
195
|
+
# Remove the event from the cache after it has been waited on
|
196
|
+
instance._completion_events.pop(task_run_id, None)
|
183
197
|
|
184
198
|
@classmethod
|
185
199
|
def instance(cls):
|
prefect/task_worker.py
CHANGED
@@ -8,10 +8,14 @@ from concurrent.futures import ThreadPoolExecutor
|
|
8
8
|
from contextlib import AsyncExitStack
|
9
9
|
from contextvars import copy_context
|
10
10
|
from typing import List, Optional
|
11
|
+
from uuid import UUID
|
11
12
|
|
12
13
|
import anyio
|
13
14
|
import anyio.abc
|
15
|
+
import pendulum
|
16
|
+
import uvicorn
|
14
17
|
from exceptiongroup import BaseExceptionGroup # novermin
|
18
|
+
from fastapi import FastAPI
|
15
19
|
from websockets.exceptions import InvalidStatusCode
|
16
20
|
|
17
21
|
from prefect import Task
|
@@ -73,8 +77,9 @@ class TaskWorker:
|
|
73
77
|
limit: Optional[int] = 10,
|
74
78
|
):
|
75
79
|
self.tasks: List[Task] = list(tasks)
|
80
|
+
self.task_keys = set(t.task_key for t in tasks if isinstance(t, Task))
|
76
81
|
|
77
|
-
self.
|
82
|
+
self._started_at: Optional[pendulum.DateTime] = None
|
78
83
|
self.stopping: bool = False
|
79
84
|
|
80
85
|
self._client = get_client()
|
@@ -89,10 +94,41 @@ class TaskWorker:
|
|
89
94
|
self._executor = ThreadPoolExecutor(max_workers=limit if limit else None)
|
90
95
|
self._limiter = anyio.CapacityLimiter(limit) if limit else None
|
91
96
|
|
97
|
+
self.in_flight_task_runs: dict[str, dict[UUID, pendulum.DateTime]] = {
|
98
|
+
task_key: {} for task_key in self.task_keys
|
99
|
+
}
|
100
|
+
self.finished_task_runs: dict[str, int] = {
|
101
|
+
task_key: 0 for task_key in self.task_keys
|
102
|
+
}
|
103
|
+
|
92
104
|
@property
|
93
|
-
def
|
105
|
+
def client_id(self) -> str:
|
94
106
|
return f"{socket.gethostname()}-{os.getpid()}"
|
95
107
|
|
108
|
+
@property
|
109
|
+
def started_at(self) -> Optional[pendulum.DateTime]:
|
110
|
+
return self._started_at
|
111
|
+
|
112
|
+
@property
|
113
|
+
def started(self) -> bool:
|
114
|
+
return self._started_at is not None
|
115
|
+
|
116
|
+
@property
|
117
|
+
def limit(self) -> Optional[int]:
|
118
|
+
return int(self._limiter.total_tokens) if self._limiter else None
|
119
|
+
|
120
|
+
@property
|
121
|
+
def current_tasks(self) -> Optional[int]:
|
122
|
+
return (
|
123
|
+
int(self._limiter.borrowed_tokens)
|
124
|
+
if self._limiter
|
125
|
+
else sum(len(runs) for runs in self.in_flight_task_runs.values())
|
126
|
+
)
|
127
|
+
|
128
|
+
@property
|
129
|
+
def available_tasks(self) -> Optional[int]:
|
130
|
+
return int(self._limiter.available_tokens) if self._limiter else None
|
131
|
+
|
96
132
|
def handle_sigterm(self, signum, frame):
|
97
133
|
"""
|
98
134
|
Shuts down the task worker when a SIGTERM is received.
|
@@ -133,11 +169,31 @@ class TaskWorker:
|
|
133
169
|
" calling .start()"
|
134
170
|
)
|
135
171
|
|
136
|
-
self.
|
172
|
+
self._started_at = None
|
137
173
|
self.stopping = True
|
138
174
|
|
139
175
|
raise StopTaskWorker
|
140
176
|
|
177
|
+
async def _acquire_token(self, task_run_id: UUID) -> bool:
|
178
|
+
try:
|
179
|
+
if self._limiter:
|
180
|
+
await self._limiter.acquire_on_behalf_of(task_run_id)
|
181
|
+
except RuntimeError:
|
182
|
+
logger.debug(f"Token already acquired for task run: {task_run_id!r}")
|
183
|
+
return False
|
184
|
+
|
185
|
+
return True
|
186
|
+
|
187
|
+
def _release_token(self, task_run_id: UUID) -> bool:
|
188
|
+
try:
|
189
|
+
if self._limiter:
|
190
|
+
self._limiter.release_on_behalf_of(task_run_id)
|
191
|
+
except RuntimeError:
|
192
|
+
logger.debug(f"No token to release for task run: {task_run_id!r}")
|
193
|
+
return False
|
194
|
+
|
195
|
+
return True
|
196
|
+
|
141
197
|
async def _subscribe_to_task_scheduling(self):
|
142
198
|
base_url = PREFECT_API_URL.value()
|
143
199
|
if base_url is None:
|
@@ -146,24 +202,26 @@ class TaskWorker:
|
|
146
202
|
"Task workers are not compatible with the ephemeral API."
|
147
203
|
)
|
148
204
|
task_keys_repr = " | ".join(
|
149
|
-
|
205
|
+
task_key.split(".")[-1].split("-")[0] for task_key in sorted(self.task_keys)
|
150
206
|
)
|
151
207
|
logger.info(f"Subscribing to runs of task(s): {task_keys_repr}")
|
152
208
|
async for task_run in Subscription(
|
153
209
|
model=TaskRun,
|
154
210
|
path="/task_runs/subscriptions/scheduled",
|
155
|
-
keys=
|
156
|
-
client_id=self.
|
211
|
+
keys=self.task_keys,
|
212
|
+
client_id=self.client_id,
|
157
213
|
base_url=base_url,
|
158
214
|
):
|
159
215
|
logger.info(f"Received task run: {task_run.id} - {task_run.name}")
|
160
|
-
|
161
|
-
|
162
|
-
|
163
|
-
self.
|
164
|
-
|
216
|
+
|
217
|
+
token_acquired = await self._acquire_token(task_run.id)
|
218
|
+
if token_acquired:
|
219
|
+
self._runs_task_group.start_soon(
|
220
|
+
self._safe_submit_scheduled_task_run, task_run
|
221
|
+
)
|
165
222
|
|
166
223
|
async def _safe_submit_scheduled_task_run(self, task_run: TaskRun):
|
224
|
+
self.in_flight_task_runs[task_run.task_key][task_run.id] = pendulum.now()
|
167
225
|
try:
|
168
226
|
await self._submit_scheduled_task_run(task_run)
|
169
227
|
except BaseException as exc:
|
@@ -172,8 +230,9 @@ class TaskWorker:
|
|
172
230
|
exc_info=exc,
|
173
231
|
)
|
174
232
|
finally:
|
175
|
-
|
176
|
-
|
233
|
+
self.in_flight_task_runs[task_run.task_key].pop(task_run.id, None)
|
234
|
+
self.finished_task_runs[task_run.task_key] += 1
|
235
|
+
self._release_token(task_run.id)
|
177
236
|
|
178
237
|
async def _submit_scheduled_task_run(self, task_run: TaskRun):
|
179
238
|
logger.debug(
|
@@ -284,9 +343,9 @@ class TaskWorker:
|
|
284
343
|
async def execute_task_run(self, task_run: TaskRun):
|
285
344
|
"""Execute a task run in the task worker."""
|
286
345
|
async with self if not self.started else asyncnullcontext():
|
287
|
-
|
288
|
-
|
289
|
-
|
346
|
+
token_acquired = await self._acquire_token(task_run.id)
|
347
|
+
if token_acquired:
|
348
|
+
await self._safe_submit_scheduled_task_run(task_run)
|
290
349
|
|
291
350
|
async def __aenter__(self):
|
292
351
|
logger.debug("Starting task worker...")
|
@@ -298,17 +357,42 @@ class TaskWorker:
|
|
298
357
|
await self._exit_stack.enter_async_context(self._runs_task_group)
|
299
358
|
self._exit_stack.enter_context(self._executor)
|
300
359
|
|
301
|
-
self.
|
360
|
+
self._started_at = pendulum.now()
|
302
361
|
return self
|
303
362
|
|
304
363
|
async def __aexit__(self, *exc_info):
|
305
364
|
logger.debug("Stopping task worker...")
|
306
|
-
self.
|
365
|
+
self._started_at = None
|
307
366
|
await self._exit_stack.__aexit__(*exc_info)
|
308
367
|
|
309
368
|
|
369
|
+
def create_status_server(task_worker: TaskWorker) -> FastAPI:
|
370
|
+
status_app = FastAPI()
|
371
|
+
|
372
|
+
@status_app.get("/status")
|
373
|
+
def status():
|
374
|
+
return {
|
375
|
+
"client_id": task_worker.client_id,
|
376
|
+
"started_at": task_worker.started_at.isoformat(),
|
377
|
+
"stopping": task_worker.stopping,
|
378
|
+
"limit": task_worker.limit,
|
379
|
+
"current": task_worker.current_tasks,
|
380
|
+
"available": task_worker.available_tasks,
|
381
|
+
"tasks": sorted(task_worker.task_keys),
|
382
|
+
"finished": task_worker.finished_task_runs,
|
383
|
+
"in_flight": {
|
384
|
+
key: {str(run): start.isoformat() for run, start in tasks.items()}
|
385
|
+
for key, tasks in task_worker.in_flight_task_runs.items()
|
386
|
+
},
|
387
|
+
}
|
388
|
+
|
389
|
+
return status_app
|
390
|
+
|
391
|
+
|
310
392
|
@sync_compatible
|
311
|
-
async def serve(
|
393
|
+
async def serve(
|
394
|
+
*tasks: Task, limit: Optional[int] = 10, status_server_port: Optional[int] = None
|
395
|
+
):
|
312
396
|
"""Serve the provided tasks so that their runs may be submitted to and executed.
|
313
397
|
in the engine. Tasks do not need to be within a flow run context to be submitted.
|
314
398
|
You must `.submit` the same task object that you pass to `serve`.
|
@@ -318,6 +402,9 @@ async def serve(*tasks: Task, limit: Optional[int] = 10):
|
|
318
402
|
given task, the task run will be submitted to the engine for execution.
|
319
403
|
- limit: The maximum number of tasks that can be run concurrently. Defaults to 10.
|
320
404
|
Pass `None` to remove the limit.
|
405
|
+
- status_server_port: An optional port on which to start an HTTP server
|
406
|
+
exposing status information about the task worker. If not provided, no
|
407
|
+
status server will run.
|
321
408
|
|
322
409
|
Example:
|
323
410
|
```python
|
@@ -339,6 +426,20 @@ async def serve(*tasks: Task, limit: Optional[int] = 10):
|
|
339
426
|
"""
|
340
427
|
task_worker = TaskWorker(*tasks, limit=limit)
|
341
428
|
|
429
|
+
status_server_task = None
|
430
|
+
if status_server_port is not None:
|
431
|
+
server = uvicorn.Server(
|
432
|
+
uvicorn.Config(
|
433
|
+
app=create_status_server(task_worker),
|
434
|
+
host="127.0.0.1",
|
435
|
+
port=status_server_port,
|
436
|
+
access_log=False,
|
437
|
+
log_level="warning",
|
438
|
+
)
|
439
|
+
)
|
440
|
+
loop = asyncio.get_event_loop()
|
441
|
+
status_server_task = loop.create_task(server.serve())
|
442
|
+
|
342
443
|
try:
|
343
444
|
await task_worker.start()
|
344
445
|
|
@@ -355,3 +456,11 @@ async def serve(*tasks: Task, limit: Optional[int] = 10):
|
|
355
456
|
|
356
457
|
except (asyncio.CancelledError, KeyboardInterrupt):
|
357
458
|
logger.info("Task worker interrupted, stopping...")
|
459
|
+
|
460
|
+
finally:
|
461
|
+
if status_server_task:
|
462
|
+
status_server_task.cancel()
|
463
|
+
try:
|
464
|
+
await status_server_task
|
465
|
+
except asyncio.CancelledError:
|
466
|
+
pass
|