prefect-client 3.0.0rc3 → 3.0.0rc4 (py3-none-any.whl)

This diff compares the contents of two publicly released versions of the package as they appear in their public registries. It is provided for informational purposes only.
prefect/__init__.py CHANGED
@@ -43,7 +43,6 @@ import prefect.runtime

 # Import modules that register types
 import prefect.serializers
-import prefect.blocks.kubernetes
 import prefect.blocks.notifications
 import prefect.blocks.system

prefect/client/subscriptions.py CHANGED
@@ -1,5 +1,5 @@
 import asyncio
-from typing import Any, Dict, Generic, List, Optional, Type, TypeVar
+from typing import Any, Dict, Generic, Iterable, Optional, Type, TypeVar

 import orjson
 import websockets
@@ -21,7 +21,7 @@ class Subscription(Generic[S]):
         self,
         model: Type[S],
         path: str,
-        keys: List[str],
+        keys: Iterable[str],
         client_id: Optional[str] = None,
         base_url: Optional[str] = None,
     ):
@@ -30,7 +30,7 @@ class Subscription(Generic[S]):
         base_url = base_url.replace("http", "ws", 1)
         self.subscription_url = f"{base_url}{path}"

-        self.keys = keys
+        self.keys = list(keys)

         self._connect = websockets.connect(
             self.subscription_url,
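
With `keys` widened from `List[str]` to `Iterable[str]` and normalized via `list(keys)`, callers may now pass any iterable, such as a set. A minimal sketch, assuming a Prefect API at the given base URL (the task keys are illustrative):

```python
from prefect.client.schemas.objects import TaskRun
from prefect.client.subscriptions import Subscription

# A set of keys now works directly; Subscription copies it into a list.
# The http:// base URL is rewritten to ws:// internally.
subscription = Subscription(
    model=TaskRun,
    path="/task_runs/subscriptions/scheduled",
    keys={"my_task", "my_other_task"},  # any Iterable[str], not just List[str]
    base_url="http://127.0.0.1:4200/api",
)
```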
prefect/flow_engine.py CHANGED
@@ -6,6 +6,7 @@ from contextlib import ExitStack, contextmanager
 from dataclasses import dataclass, field
 from typing import (
     Any,
+    AsyncGenerator,
     Callable,
     Coroutine,
     Dict,
@@ -50,12 +51,13 @@ from prefect.states import (
     return_value_to_state,
 )
 from prefect.utilities.asyncutils import run_coro_as_sync
-from prefect.utilities.callables import call_with_parameters
+from prefect.utilities.callables import call_with_parameters, parameters_to_args_kwargs
 from prefect.utilities.collections import visit_collection
 from prefect.utilities.engine import (
     _get_hook_name,
     _resolve_custom_flow_run_name,
     capture_sigterm,
+    link_state_to_result,
     propose_state_sync,
     resolve_to_final_result,
 )
@@ -632,6 +634,80 @@
     return engine.state if return_type == "state" else engine.result()


+def run_generator_flow_sync(
+    flow: Flow[P, R],
+    flow_run: Optional[FlowRun] = None,
+    parameters: Optional[Dict[str, Any]] = None,
+    wait_for: Optional[Iterable[PrefectFuture]] = None,
+    return_type: Literal["state", "result"] = "result",
+) -> Generator[R, None, None]:
+    if return_type != "result":
+        raise ValueError("The return_type for a generator flow must be 'result'")
+
+    engine = FlowRunEngine[P, R](
+        flow=flow, parameters=parameters, flow_run=flow_run, wait_for=wait_for
+    )
+
+    with engine.start():
+        while engine.is_running():
+            with engine.run_context():
+                call_args, call_kwargs = parameters_to_args_kwargs(
+                    flow.fn, engine.parameters or {}
+                )
+                gen = flow.fn(*call_args, **call_kwargs)
+                try:
+                    while True:
+                        gen_result = next(gen)
+                        # link the current state to the result for dependency tracking
+                        link_state_to_result(engine.state, gen_result)
+                        yield gen_result
+                except StopIteration as exc:
+                    engine.handle_success(exc.value)
+                except GeneratorExit as exc:
+                    engine.handle_success(None)
+                    gen.throw(exc)
+
+    return engine.result()
+
+
+async def run_generator_flow_async(
+    flow: Flow[P, R],
+    flow_run: Optional[FlowRun] = None,
+    parameters: Optional[Dict[str, Any]] = None,
+    wait_for: Optional[Iterable[PrefectFuture]] = None,
+    return_type: Literal["state", "result"] = "result",
+) -> AsyncGenerator[R, None]:
+    if return_type != "result":
+        raise ValueError("The return_type for a generator flow must be 'result'")
+
+    engine = FlowRunEngine[P, R](
+        flow=flow, parameters=parameters, flow_run=flow_run, wait_for=wait_for
+    )
+
+    with engine.start():
+        while engine.is_running():
+            with engine.run_context():
+                call_args, call_kwargs = parameters_to_args_kwargs(
+                    flow.fn, engine.parameters or {}
+                )
+                gen = flow.fn(*call_args, **call_kwargs)
+                try:
+                    while True:
+                        # can't use anext in Python < 3.10
+                        gen_result = await gen.__anext__()
+                        # link the current state to the result for dependency tracking
+                        link_state_to_result(engine.state, gen_result)
+                        yield gen_result
+                except (StopAsyncIteration, GeneratorExit) as exc:
+                    engine.handle_success(None)
+                    if isinstance(exc, GeneratorExit):
+                        gen.throw(exc)
+
+    # async generators can't return, but we can raise failures here
+    if engine.state.is_failed():
+        engine.result()
+
+
 def run_flow(
     flow: Flow[P, R],
     flow_run: Optional[FlowRun] = None,
@@ -646,7 +722,11 @@ def run_flow(
         wait_for=wait_for,
         return_type=return_type,
     )
-    if flow.isasync:
+    if flow.isasync and flow.isgenerator:
+        return run_generator_flow_async(**kwargs)
+    elif flow.isgenerator:
+        return run_generator_flow_sync(**kwargs)
+    elif flow.isasync:
         return run_flow_async(**kwargs)
     else:
         return run_flow_sync(**kwargs)
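
These additions let flows be written as sync or async generators, with `run_flow` dispatching on both `isasync` and `isgenerator`. A minimal sketch of what this enables (flow name and values are illustrative):

```python
from prefect import flow

@flow
def stream_numbers(limit: int = 3):
    # each yielded value is linked to the flow run's state before being emitted
    for i in range(limit):
        yield i

# consuming the generator drives the flow run; exhausting it (StopIteration)
# marks the run as successful
for value in stream_numbers():
    print(value)
```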
prefect/flows.py CHANGED
@@ -89,7 +89,6 @@ from prefect.task_runners import TaskRunner, ThreadPoolTaskRunner
 from prefect.types import BANNED_CHARACTERS, WITHOUT_BANNED_CHARACTERS
 from prefect.utilities.annotations import NotSet
 from prefect.utilities.asyncutils import (
-    is_async_fn,
     run_sync_in_worker_thread,
     sync_compatible,
 )
@@ -289,7 +288,18 @@ class Flow(Generic[P, R]):
         self.description = description or inspect.getdoc(fn)
         update_wrapper(self, fn)
         self.fn = fn
-        self.isasync = is_async_fn(self.fn)
+
+        # the flow is considered async if its function is async or an async
+        # generator
+        self.isasync = inspect.iscoroutinefunction(
+            self.fn
+        ) or inspect.isasyncgenfunction(self.fn)
+
+        # the flow is considered a generator if its function is a generator or
+        # an async generator
+        self.isgenerator = inspect.isgeneratorfunction(
+            self.fn
+        ) or inspect.isasyncgenfunction(self.fn)

         raise_for_reserved_arguments(self.fn, ["return_state", "wait_for"])

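
The `inspect` predicates split the four kinds of functions along two axes, which is why `isasyncgenfunction` is OR'd into both flags. A quick sketch of the standard-library behavior, shown for reference:

```python
import inspect

def f(): return 1        # isasync=False, isgenerator=False
def g(): yield 1         # isasync=False, isgenerator=True
async def h(): return 1  # isasync=True,  isgenerator=False
async def k(): yield 1   # isasync=True,  isgenerator=True

assert inspect.isgeneratorfunction(g) and not inspect.iscoroutinefunction(g)
assert inspect.iscoroutinefunction(h) and not inspect.isasyncgenfunction(h)
# async generators are neither coroutine functions nor plain generator
# functions, so neither predicate alone would classify them correctly
assert inspect.isasyncgenfunction(k)
assert not inspect.iscoroutinefunction(k) and not inspect.isgeneratorfunction(k)
```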
prefect/futures.py CHANGED
@@ -56,7 +56,7 @@ class PrefectFuture(abc.ABC):
     def wait(self, timeout: Optional[float] = None) -> None:
         ...
         """
-       Wait for the task run to complete.
+        Wait for the task run to complete.

         If the task run has already completed, this method will return immediately.

@@ -163,6 +163,10 @@ class PrefectDistributedFuture(PrefectFuture):
             )
             return

+        # Ask for the instance of TaskRunWaiter _now_ so that it's already running and
+        # can catch the completion event if it happens before we start listening for it.
+        TaskRunWaiter.instance()
+
         # Read task run to see if it is still running
         async with get_client() as client:
             task_run = await client.read_task_run(task_run_id=self._task_run_id)
@@ -245,6 +249,10 @@ def resolve_futures_to_states(
         context={},
     )

+    # if no futures were found, return the original expression
+    if not futures:
+        return expr
+
     # Get final states for each future
     states = []
     for future in futures:
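
The early return means a collection containing no futures now passes through untouched instead of being rebuilt. A minimal sketch of the behavior (illustrative values):

```python
from prefect.futures import resolve_futures_to_states

# With no PrefectFuture anywhere in the structure, the input is returned
# as-is per the early return above
data = {"a": 1, "items": [2, 3]}
assert resolve_futures_to_states(data) is data
```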
prefect/results.py CHANGED
@@ -38,7 +38,6 @@ from prefect.settings import (
     PREFECT_RESULTS_DEFAULT_SERIALIZER,
     PREFECT_RESULTS_PERSIST_BY_DEFAULT,
     PREFECT_TASK_SCHEDULING_DEFAULT_STORAGE_BLOCK,
-    default_result_storage_block_name,
 )
 from prefect.utilities.annotations import NotSet
 from prefect.utilities.asyncutils import sync_compatible
@@ -62,35 +61,15 @@ logger = get_logger("results")
 P = ParamSpec("P")
 R = TypeVar("R")

-
-@sync_compatible
-async def get_default_result_storage() -> ResultStorage:
-    """
-    Generate a default file system for result storage.
-    """
-    try:
-        return await Block.load(PREFECT_DEFAULT_RESULT_STORAGE_BLOCK.value())
-    except ValueError as e:
-        if "Unable to find" not in str(e):
-            raise e
-        elif (
-            PREFECT_DEFAULT_RESULT_STORAGE_BLOCK.value()
-            == default_result_storage_block_name()
-        ):
-            return LocalFileSystem(basepath=PREFECT_LOCAL_STORAGE_PATH.value())
-        else:
-            raise
+_default_storages: Dict[Tuple[str, str], WritableFileSystem] = {}


-_default_task_scheduling_storages: Dict[Tuple[str, str], WritableFileSystem] = {}
-
-
-async def get_or_create_default_task_scheduling_storage() -> ResultStorage:
+async def _get_or_create_default_storage(block_document_slug: str) -> ResultStorage:
     """
-    Generate a default file system for background task parameter/result storage.
+    Generate a default file system for storage.
     """
     default_storage_name, storage_path = cache_key = (
-        PREFECT_TASK_SCHEDULING_DEFAULT_STORAGE_BLOCK.value(),
+        block_document_slug,
         PREFECT_LOCAL_STORAGE_PATH.value(),
     )

@@ -105,8 +84,8 @@ async def get_or_create_default_task_scheduling_storage() -> ResultStorage:
         if block_type_slug == "local-file-system":
             block = LocalFileSystem(basepath=storage_path)
         else:
-            raise Exception(
-                "The default task storage block does not exist, but it is of type "
+            raise ValueError(
+                "The default storage block does not exist, but it is of type "
                 f"'{block_type_slug}' which cannot be created implicitly. Please create "
                 "the block manually."
             )
@@ -123,13 +102,32 @@
         return block

     try:
-        return _default_task_scheduling_storages[cache_key]
+        return _default_storages[cache_key]
     except KeyError:
         storage = await get_storage()
-        _default_task_scheduling_storages[cache_key] = storage
+        _default_storages[cache_key] = storage
         return storage


+@sync_compatible
+async def get_or_create_default_result_storage() -> ResultStorage:
+    """
+    Generate a default file system for result storage.
+    """
+    return await _get_or_create_default_storage(
+        PREFECT_DEFAULT_RESULT_STORAGE_BLOCK.value()
+    )
+
+
+async def get_or_create_default_task_scheduling_storage() -> ResultStorage:
+    """
+    Generate a default file system for background task parameter/result storage.
+    """
+    return await _get_or_create_default_storage(
+        PREFECT_TASK_SCHEDULING_DEFAULT_STORAGE_BLOCK.value()
+    )
+
+
 def get_default_result_serializer() -> ResultSerializer:
     """
     Generate a default file system for result storage.
@@ -210,7 +208,9 @@ class ResultFactory(BaseModel):
             kwargs.pop(key)

         # Apply defaults
-        kwargs.setdefault("result_storage", await get_default_result_storage())
+        kwargs.setdefault(
+            "result_storage", await get_or_create_default_result_storage()
+        )
         kwargs.setdefault("result_serializer", get_default_result_serializer())
         kwargs.setdefault("persist_result", get_default_persist_setting())
         kwargs.setdefault("cache_result_in_memory", True)
@@ -280,7 +280,9 @@
         """
         Create a new result factory for a task.
         """
-        return await cls._from_task(task, get_default_result_storage, client=client)
+        return await cls._from_task(
+            task, get_or_create_default_result_storage, client=client
+        )

     @classmethod
     @inject_client
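
Both public helpers now share one cached `_get_or_create_default_storage` path, keyed by block slug and local storage path, creating a `local-file-system` block on first use when none exists. A minimal sketch of calling the result-storage variant; it is `@sync_compatible`, so this works from synchronous code:

```python
from prefect.results import get_or_create_default_result_storage

# Resolves the default result storage block, creating it on first use;
# subsequent calls for the same slug and path hit the in-process cache.
storage = get_or_create_default_result_storage()
print(storage)
```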
prefect/settings.py CHANGED
@@ -1541,7 +1541,7 @@ The maximum number of retries to queue for submission.

 PREFECT_TASK_SCHEDULING_PENDING_TASK_TIMEOUT = Setting(
     timedelta,
-    default=timedelta(seconds=30),
+    default=timedelta(0),
 )
 """
 How long before a PENDING task are made available to another task worker. In practice,
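
The default drops from 30 seconds to `timedelta(0)`, disabling the pending-task timeout unless overridden. A sketch of restoring the old behavior through the settings API (the override value is illustrative):

```python
from datetime import timedelta

from prefect.settings import (
    PREFECT_TASK_SCHEDULING_PENDING_TASK_TIMEOUT,
    temporary_settings,
)

# Re-enable a 30-second PENDING timeout within this context only
with temporary_settings(
    {PREFECT_TASK_SCHEDULING_PENDING_TASK_TIMEOUT: timedelta(seconds=30)}
):
    assert PREFECT_TASK_SCHEDULING_PENDING_TASK_TIMEOUT.value() == timedelta(seconds=30)
```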
prefect/task_engine.py CHANGED
@@ -417,9 +417,7 @@ class TaskRunEngine(Generic[P, R]):
                     log_prints=log_prints,
                     task_run=self.task_run,
                     parameters=self.parameters,
-                    result_factory=run_coro_as_sync(
-                        ResultFactory.from_autonomous_task(self.task)
-                    ),  # type: ignore
+                    result_factory=run_coro_as_sync(ResultFactory.from_task(self.task)),  # type: ignore
                     client=client,
                 )
             )
@@ -467,9 +465,6 @@
                     extra_task_inputs=dependencies,
                 )
             )
-            self.logger.info(
-                f"Created task run {self.task_run.name!r} for task {self.task.name!r}"
-            )
             # Emit an event to capture that the task run was in the `PENDING` state.
             self._last_event = emit_task_run_state_change_event(
                 task_run=self.task_run,
@@ -478,6 +473,10 @@
             )

             with self.setup_run_context():
+                # setup_run_context might update the task run name, so log creation here
+                self.logger.info(
+                    f"Created task run {self.task_run.name!r} for task {self.task.name!r}"
+                )
                 yield self

         except Exception:
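
Moving the log inside `setup_run_context()` matters because context setup can rename the run, for example via a templated `task_run_name`; logging afterward records the final name. An illustrative sketch (the task and template are hypothetical):

```python
from prefect import task

# The run name is resolved from parameters during context setup, so the
# "Created task run ..." log line now reports the templated name.
@task(task_run_name="process-{item}")
def process(item: str) -> str:
    return item.upper()
```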
prefect/task_runs.py CHANGED
@@ -92,13 +92,18 @@ class TaskRunWaiter:
             raise RuntimeError("TaskRunWaiter must run on the global loop thread.")

         self._loop = loop_thread._loop
-        self._consumer_task = self._loop.create_task(self._consume_events())
+
+        consumer_started = asyncio.Event()
+        self._consumer_task = self._loop.create_task(
+            self._consume_events(consumer_started)
+        )
+        asyncio.run_coroutine_threadsafe(consumer_started.wait(), self._loop)

         loop_thread.add_shutdown_call(create_call(self.stop))
         atexit.register(self.stop)
         self._started = True

-    async def _consume_events(self):
+    async def _consume_events(self, consumer_started: asyncio.Event):
         async with get_events_subscriber(
             filter=EventFilter(
                 event=EventNameFilter(
@@ -109,6 +114,7 @@
                 )
             )
         ) as subscriber:
+            consumer_started.set()
             async for event in subscriber:
                 try:
                     self.logger.debug(
@@ -119,6 +125,7 @@
                             "prefect.task-run.", ""
                         )
                     )
+
                     with self._observed_completed_task_runs_lock:
                         # Cache the task run ID for a short period of time to avoid
                         # unnecessary waits
@@ -172,14 +179,21 @@
         # when the event is received
         instance._completion_events[task_run_id] = finished_event

-        with anyio.move_on_after(delay=timeout):
-            await from_async.wait_for_call_in_loop_thread(
-                create_call(finished_event.wait)
-            )
+        try:
+            # Now check one more time whether the task run arrived before we start to
+            # wait on it, in case it came in while we were setting up the event above.
+            with instance._observed_completed_task_runs_lock:
+                if task_run_id in instance._observed_completed_task_runs:
+                    return

-        with instance._completion_events_lock:
-            # Remove the event from the cache after it has been waited on
-            instance._completion_events.pop(task_run_id, None)
+            with anyio.move_on_after(delay=timeout):
+                await from_async.wait_for_call_in_loop_thread(
+                    create_call(finished_event.wait)
+                )
+        finally:
+            with instance._completion_events_lock:
+                # Remove the event from the cache after it has been waited on
+                instance._completion_events.pop(task_run_id, None)

     @classmethod
     def instance(cls):
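
Both changes close the same race: a completion event could fire between registering interest and starting to wait. The fix is a startup handshake plus a classic check-then-wait pattern. A distilled sketch of the idea, independent of Prefect's internals:

```python
import asyncio

observed: set[str] = set()           # completions already seen by the consumer
events: dict[str, asyncio.Event] = {}

async def wait_for(key: str, timeout: float) -> None:
    finished = events.setdefault(key, asyncio.Event())
    try:
        # re-check after registering: the completion may have landed while
        # we were setting up the event, so don't block on it forever
        if key in observed:
            return
        try:
            await asyncio.wait_for(finished.wait(), timeout)
        except asyncio.TimeoutError:
            pass
    finally:
        # always clean up, even on early return or timeout
        events.pop(key, None)
```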
prefect/task_worker.py CHANGED
@@ -8,10 +8,14 @@ from concurrent.futures import ThreadPoolExecutor
 from contextlib import AsyncExitStack
 from contextvars import copy_context
 from typing import List, Optional
+from uuid import UUID

 import anyio
 import anyio.abc
+import pendulum
+import uvicorn
 from exceptiongroup import BaseExceptionGroup  # novermin
+from fastapi import FastAPI
 from websockets.exceptions import InvalidStatusCode

 from prefect import Task
@@ -73,8 +77,9 @@ class TaskWorker:
         limit: Optional[int] = 10,
     ):
         self.tasks: List[Task] = list(tasks)
+        self.task_keys = set(t.task_key for t in tasks if isinstance(t, Task))

-        self.started: bool = False
+        self._started_at: Optional[pendulum.DateTime] = None
         self.stopping: bool = False

         self._client = get_client()
@@ -89,10 +94,41 @@
         self._executor = ThreadPoolExecutor(max_workers=limit if limit else None)
         self._limiter = anyio.CapacityLimiter(limit) if limit else None

+        self.in_flight_task_runs: dict[str, dict[UUID, pendulum.DateTime]] = {
+            task_key: {} for task_key in self.task_keys
+        }
+        self.finished_task_runs: dict[str, int] = {
+            task_key: 0 for task_key in self.task_keys
+        }
+
     @property
-    def _client_id(self) -> str:
+    def client_id(self) -> str:
         return f"{socket.gethostname()}-{os.getpid()}"

+    @property
+    def started_at(self) -> Optional[pendulum.DateTime]:
+        return self._started_at
+
+    @property
+    def started(self) -> bool:
+        return self._started_at is not None
+
+    @property
+    def limit(self) -> Optional[int]:
+        return int(self._limiter.total_tokens) if self._limiter else None
+
+    @property
+    def current_tasks(self) -> Optional[int]:
+        return (
+            int(self._limiter.borrowed_tokens)
+            if self._limiter
+            else sum(len(runs) for runs in self.in_flight_task_runs.values())
+        )
+
+    @property
+    def available_tasks(self) -> Optional[int]:
+        return int(self._limiter.available_tokens) if self._limiter else None
+
     def handle_sigterm(self, signum, frame):
         """
         Shuts down the task worker when a SIGTERM is received.
@@ -133,11 +169,31 @@
                 " calling .start()"
             )

-        self.started = False
+        self._started_at = None
         self.stopping = True

         raise StopTaskWorker

+    async def _acquire_token(self, task_run_id: UUID) -> bool:
+        try:
+            if self._limiter:
+                await self._limiter.acquire_on_behalf_of(task_run_id)
+        except RuntimeError:
+            logger.debug(f"Token already acquired for task run: {task_run_id!r}")
+            return False
+
+        return True
+
+    def _release_token(self, task_run_id: UUID) -> bool:
+        try:
+            if self._limiter:
+                self._limiter.release_on_behalf_of(task_run_id)
+        except RuntimeError:
+            logger.debug(f"No token to release for task run: {task_run_id!r}")
+            return False
+
+        return True
+
     async def _subscribe_to_task_scheduling(self):
         base_url = PREFECT_API_URL.value()
         if base_url is None:
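
The `_acquire_token`/`_release_token` helpers wrap anyio's `CapacityLimiter`, which raises `RuntimeError` on a double acquire or release for the same borrower; the wrappers turn that into a boolean. A distilled sketch of the underlying anyio behavior:

```python
import anyio

async def main() -> None:
    limiter = anyio.CapacityLimiter(2)
    borrower = object()  # any hashable object identifies the borrower

    await limiter.acquire_on_behalf_of(borrower)
    try:
        await limiter.acquire_on_behalf_of(borrower)  # second acquire...
    except RuntimeError:
        print("already borrowing a token")           # ...raises RuntimeError
    finally:
        limiter.release_on_behalf_of(borrower)

anyio.run(main)
```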
@@ -146,24 +202,26 @@
                 "Task workers are not compatible with the ephemeral API."
             )
         task_keys_repr = " | ".join(
-            t.task_key.split(".")[-1].split("-")[0] for t in self.tasks
+            task_key.split(".")[-1].split("-")[0] for task_key in sorted(self.task_keys)
         )
         logger.info(f"Subscribing to runs of task(s): {task_keys_repr}")
         async for task_run in Subscription(
             model=TaskRun,
             path="/task_runs/subscriptions/scheduled",
-            keys=[task.task_key for task in self.tasks],
-            client_id=self._client_id,
+            keys=self.task_keys,
+            client_id=self.client_id,
             base_url=base_url,
         ):
             logger.info(f"Received task run: {task_run.id} - {task_run.name}")
-            if self._limiter:
-                await self._limiter.acquire_on_behalf_of(task_run.id)
-            self._runs_task_group.start_soon(
-                self._safe_submit_scheduled_task_run, task_run
-            )
+
+            token_acquired = await self._acquire_token(task_run.id)
+            if token_acquired:
+                self._runs_task_group.start_soon(
+                    self._safe_submit_scheduled_task_run, task_run
+                )

     async def _safe_submit_scheduled_task_run(self, task_run: TaskRun):
+        self.in_flight_task_runs[task_run.task_key][task_run.id] = pendulum.now()
         try:
             await self._submit_scheduled_task_run(task_run)
         except BaseException as exc:
@@ -172,8 +230,9 @@
                 exc_info=exc,
             )
         finally:
-            if self._limiter:
-                self._limiter.release_on_behalf_of(task_run.id)
+            self.in_flight_task_runs[task_run.task_key].pop(task_run.id, None)
+            self.finished_task_runs[task_run.task_key] += 1
+            self._release_token(task_run.id)

     async def _submit_scheduled_task_run(self, task_run: TaskRun):
         logger.debug(
@@ -284,9 +343,9 @@
     async def execute_task_run(self, task_run: TaskRun):
         """Execute a task run in the task worker."""
         async with self if not self.started else asyncnullcontext():
-            if self._limiter:
-                await self._limiter.acquire_on_behalf_of(task_run.id)
-            await self._safe_submit_scheduled_task_run(task_run)
+            token_acquired = await self._acquire_token(task_run.id)
+            if token_acquired:
+                await self._safe_submit_scheduled_task_run(task_run)

     async def __aenter__(self):
         logger.debug("Starting task worker...")
@@ -298,17 +357,42 @@
         await self._exit_stack.enter_async_context(self._runs_task_group)
         self._exit_stack.enter_context(self._executor)

-        self.started = True
+        self._started_at = pendulum.now()
         return self

     async def __aexit__(self, *exc_info):
         logger.debug("Stopping task worker...")
-        self.started = False
+        self._started_at = None
         await self._exit_stack.__aexit__(*exc_info)


+def create_status_server(task_worker: TaskWorker) -> FastAPI:
+    status_app = FastAPI()
+
+    @status_app.get("/status")
+    def status():
+        return {
+            "client_id": task_worker.client_id,
+            "started_at": task_worker.started_at.isoformat(),
+            "stopping": task_worker.stopping,
+            "limit": task_worker.limit,
+            "current": task_worker.current_tasks,
+            "available": task_worker.available_tasks,
+            "tasks": sorted(task_worker.task_keys),
+            "finished": task_worker.finished_task_runs,
+            "in_flight": {
+                key: {str(run): start.isoformat() for run, start in tasks.items()}
+                for key, tasks in task_worker.in_flight_task_runs.items()
+            },
+        }
+
+    return status_app
+
+
 @sync_compatible
-async def serve(*tasks: Task, limit: Optional[int] = 10):
+async def serve(
+    *tasks: Task, limit: Optional[int] = 10, status_server_port: Optional[int] = None
+):
     """Serve the provided tasks so that their runs may be submitted to and executed.
     in the engine. Tasks do not need to be within a flow run context to be submitted.
     You must `.submit` the same task object that you pass to `serve`.
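
When a status server is running, its single `/status` route reports the worker's identity, capacity, and per-task-key counters. A sketch of querying it with httpx, assuming the server was started on port 4422:

```python
import httpx

# hypothetical port; must match the status_server_port passed to serve()
status = httpx.get("http://127.0.0.1:4422/status").json()
print(status["client_id"], status["current"], "/", status["limit"])
print(status["in_flight"])
```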
@@ -318,6 +402,9 @@ async def serve(*tasks: Task, limit: Optional[int] = 10):
         given task, the task run will be submitted to the engine for execution.
     - limit: The maximum number of tasks that can be run concurrently. Defaults to 10.
         Pass `None` to remove the limit.
+    - status_server_port: An optional port on which to start an HTTP server
+        exposing status information about the task worker. If not provided, no
+        status server will run.

     Example:
         ```python
@@ -339,6 +426,20 @@
     """
     task_worker = TaskWorker(*tasks, limit=limit)

+    status_server_task = None
+    if status_server_port is not None:
+        server = uvicorn.Server(
+            uvicorn.Config(
+                app=create_status_server(task_worker),
+                host="127.0.0.1",
+                port=status_server_port,
+                access_log=False,
+                log_level="warning",
+            )
+        )
+        loop = asyncio.get_event_loop()
+        status_server_task = loop.create_task(server.serve())
+
     try:
         await task_worker.start()

@@ -355,3 +456,11 @@

     except (asyncio.CancelledError, KeyboardInterrupt):
         logger.info("Task worker interrupted, stopping...")
+
+    finally:
+        if status_server_task:
+            status_server_task.cancel()
+            try:
+                await status_server_task
+            except asyncio.CancelledError:
+                pass
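
Putting it together, a worker can now expose its status while serving tasks. A minimal sketch (the task body and port are illustrative); `serve` is `@sync_compatible`, so it can be called directly from a script:

```python
from prefect import task
from prefect.task_worker import serve

@task
def process(item: str) -> str:
    return item.upper()

# Serve the task and expose http://127.0.0.1:4422/status for monitoring
serve(process, limit=10, status_server_port=4422)
```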