stagegate 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- stagegate/__init__.py +20 -0
- stagegate/_records.py +150 -0
- stagegate/_states.py +34 -0
- stagegate/_wait_utils.py +97 -0
- stagegate/exceptions.py +13 -0
- stagegate/handles.py +216 -0
- stagegate/pipeline.py +122 -0
- stagegate/scheduler.py +482 -0
- stagegate/wait.py +13 -0
- stagegate-0.1.0.dist-info/METADATA +226 -0
- stagegate-0.1.0.dist-info/RECORD +13 -0
- stagegate-0.1.0.dist-info/WHEEL +4 -0
- stagegate-0.1.0.dist-info/licenses/LICENSE +21 -0
stagegate/__init__.py
ADDED
|
@@ -0,0 +1,20 @@
|
|
|
1
|
+
"""Public package namespace for stage-aware local pipeline execution."""
|
|
2
|
+
|
|
3
|
+
from .exceptions import CancelledError, UnknownResourceError, UnschedulableTaskError
|
|
4
|
+
from .handles import PipelineHandle, TaskHandle
|
|
5
|
+
from .pipeline import Pipeline
|
|
6
|
+
from .scheduler import Scheduler
|
|
7
|
+
from .wait import ALL_COMPLETED, FIRST_COMPLETED, FIRST_EXCEPTION
|
|
8
|
+
|
|
9
|
+
# Public API of the package; ``from stagegate import *`` binds exactly these
# names (all of them are imported above).
__all__ = [
    "Scheduler",
    "Pipeline",
    "TaskHandle",
    "PipelineHandle",
    "FIRST_COMPLETED",
    "FIRST_EXCEPTION",
    "ALL_COMPLETED",
    "CancelledError",
    "UnknownResourceError",
    "UnschedulableTaskError",
]
|
stagegate/_records.py
ADDED
|
@@ -0,0 +1,150 @@
|
|
|
1
|
+
"""Internal scheduler record and queue-entry definitions."""
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
from collections import deque
|
|
6
|
+
from collections.abc import Callable
|
|
7
|
+
from dataclasses import dataclass, field
|
|
8
|
+
from typing import TYPE_CHECKING, Any
|
|
9
|
+
|
|
10
|
+
from ._states import PipelineState, SchedulerState, TaskState
|
|
11
|
+
|
|
12
|
+
if TYPE_CHECKING:
|
|
13
|
+
from .pipeline import Pipeline
|
|
14
|
+
from .scheduler import Scheduler
|
|
15
|
+
|
|
16
|
+
|
|
17
|
+
# Task states that count as "finished": TaskRecord.is_terminal() and
# TaskHandle.done()/result()/exception() all test membership in this set.
TerminalTaskState = frozenset(
    (TaskState.SUCCEEDED, TaskState.FAILED, TaskState.CANCELLED)
)
|
|
24
|
+
|
|
25
|
+
# Pipeline states that count as "finished": PipelineRecord.is_terminal() and
# PipelineHandle.done()/result()/exception() all test membership in this set.
TerminalPipelineState = frozenset(
    (PipelineState.SUCCEEDED, PipelineState.FAILED, PipelineState.CANCELLED)
)
|
|
32
|
+
|
|
33
|
+
|
|
34
|
+
@dataclass(order=True, frozen=True, slots=True)
class PipelineQueueEntry:
    """FIFO queue entry for pipeline admission.

    Entries compare by ``enqueue_seq`` only; ``record`` is excluded from
    comparisons (``PipelineRecord`` itself defines no ordering).
    """

    # Admission sequence number and the sole comparison field
    # (presumably drawn from SchedulerRuntime.next_pipeline_enqueue_seq).
    enqueue_seq: int
    # The pipeline this entry represents; ignored by ==/< comparisons.
    record: PipelineRecord = field(compare=False)
|
|
40
|
+
|
|
41
|
+
|
|
42
|
+
@dataclass(order=True, frozen=True, slots=True)
class TaskPriorityKey:
    """Deterministic ordering key for queued tasks.

    Keys compare lexicographically in field-declaration order, and a smaller
    key sorts first. Because ``neg_stage`` holds the negated stage index,
    tasks from later pipeline stages sort ahead of earlier ones. Ties are
    broken by pipeline enqueue order, then per-pipeline submission order,
    then global submission order.
    """

    # Negated stage snapshot (see TaskRecord.priority_key).
    neg_stage: int
    # Enqueue sequence of the pipeline that submitted the task.
    pipeline_enqueue_seq: int
    # Submission order of the task within its pipeline.
    pipeline_local_task_seq: int
    # Submission order of the task across the whole scheduler.
    global_task_submit_seq: int
|
|
50
|
+
|
|
51
|
+
|
|
52
|
+
@dataclass(order=True, frozen=True, slots=True)
|
|
53
|
+
class TaskQueueEntry:
|
|
54
|
+
"""Priority-queue entry for task admission control."""
|
|
55
|
+
|
|
56
|
+
priority: TaskPriorityKey
|
|
57
|
+
record: TaskRecord = field(compare=False)
|
|
58
|
+
|
|
59
|
+
@classmethod
|
|
60
|
+
def from_record(cls, record: TaskRecord) -> TaskQueueEntry:
|
|
61
|
+
"""Build a queue entry from the record's canonical priority key."""
|
|
62
|
+
|
|
63
|
+
return cls(priority=record.priority_key(), record=record)
|
|
64
|
+
|
|
65
|
+
|
|
66
|
+
@dataclass(order=True, frozen=True, slots=True)
class ReadyQueueEntry:
    """FIFO queue entry for tasks admitted but not yet started by a worker.

    Entries compare by ``ready_seq`` only; ``record`` is excluded from
    comparisons.
    """

    # Sequence number assigned when the task became ready (presumably from
    # SchedulerRuntime.next_ready_seq -- confirm in scheduler.py).
    ready_seq: int
    # The task record this entry refers to; ignored by ==/< comparisons.
    record: TaskRecord = field(compare=False)
|
|
72
|
+
|
|
73
|
+
|
|
74
|
+
@dataclass(slots=True)
class PipelineRecord:
    """Mutable source-of-truth record for one pipeline instance.

    ``PipelineHandle`` methods read and mutate this record while holding the
    owning scheduler's ``_condition``.
    """

    # Owning scheduler; PipelineHandle compares it by identity.
    scheduler: Scheduler
    # The user's Pipeline object (Pipeline.wait uses it for the owner check).
    pipeline: Pipeline
    # Identifier used for PipelineHandle equality and hashing.
    pipeline_id: int
    # Admission sequence number (see PipelineQueueEntry).
    enqueue_seq: int
    # Lifecycle state; new records start QUEUED.
    state: PipelineState = PipelineState.QUEUED
    # Current stage; incremented by Pipeline.stage_forward().
    stage_index: int = 0
    # Presumably the next pipeline_local_task_seq to hand out -- confirm.
    next_task_seq: int = 0
    # Value returned by PipelineHandle.result() on success.
    result_value: Any = None
    # Exception returned/raised by PipelineHandle on failure.
    exception: BaseException | None = None
    # Thread allowed to call pipeline control methods
    # (checked by Pipeline._require_control_context).
    coordinator_thread_ident: int | None = None
    # Presumably the tasks submitted by this pipeline -- confirm.
    task_records: list[TaskRecord] = field(default_factory=list)

    def is_terminal(self) -> bool:
        """Return whether the pipeline is SUCCEEDED, FAILED or CANCELLED."""

        return self.state in TerminalPipelineState
|
|
94
|
+
|
|
95
|
+
|
|
96
|
+
@dataclass(slots=True)
class TaskRecord:
    """Mutable source-of-truth record for one scheduled task.

    Besides the callable and its inputs, the record holds the sequence
    numbers that determine queue priority and the outcome
    (``result_value`` / ``exception``) that ``TaskHandle`` reports.
    """

    # Owning scheduler; TaskHandle compares it by identity.
    scheduler: Scheduler
    # Record of the pipeline that submitted this task.
    pipeline_record: PipelineRecord
    # Identifier used for TaskHandle equality and hashing.
    task_id: int
    # What to run and with which inputs (presumably copied from TaskBuilder).
    fn: Callable[..., Any]
    resources_required: dict[str, int | float]
    args: tuple[Any, ...] = ()
    kwargs: dict[str, Any] = field(default_factory=dict)
    name: str | None = None
    # Ordering inputs consumed by priority_key().
    stage_snapshot: int = 0
    pipeline_enqueue_seq: int = 0
    pipeline_local_task_seq: int = 0
    global_task_submit_seq: int = 0
    # Lifecycle state and terminal outcome.
    state: TaskState = TaskState.QUEUED
    result_value: Any = None
    exception: BaseException | None = None
    # Presumably filled in when the task enters the ready queue and when a
    # worker thread picks it up -- confirm in scheduler.py.
    ready_seq: int | None = None
    worker_thread_ident: int | None = None

    def priority_key(self) -> TaskPriorityKey:
        """Return this task's ordering key for the task priority queue.

        The stage is negated so that tasks from later pipeline stages compare
        as smaller and therefore sort first.
        """

        return TaskPriorityKey(
            -self.stage_snapshot,
            self.pipeline_enqueue_seq,
            self.pipeline_local_task_seq,
            self.global_task_submit_seq,
        )

    def is_terminal(self) -> bool:
        """Report whether ``state`` is SUCCEEDED, FAILED or CANCELLED."""

        current = self.state
        return current in TerminalTaskState
|
|
132
|
+
|
|
133
|
+
|
|
134
|
+
@dataclass(slots=True)
class SchedulerRuntime:
    """Mutable internal scheduler-wide runtime bookkeeping.

    NOTE(review): the mutation rules live in scheduler.py (not shown here);
    presumably every field is touched only while holding the scheduler's
    condition lock -- confirm.
    """

    # Scheduler lifecycle; starts OPEN.
    state: SchedulerState = SchedulerState.OPEN
    # Counters for ids and sequence numbers (presumably read and then
    # incremented by the scheduler when it creates records/entries).
    next_pipeline_id: int = 0
    next_pipeline_enqueue_seq: int = 0
    next_task_id: int = 0
    next_global_task_submit_seq: int = 0
    next_ready_seq: int = 0
    # Count of admitted tasks (exact accounting defined in scheduler.py).
    admitted_task_count: int = 0
    # Pipelines awaiting admission, in FIFO order.
    pipeline_queue: deque[PipelineQueueEntry] = field(default_factory=deque)
    # Tasks awaiting admission; a list of orderable entries, presumably
    # maintained as a heapq priority queue -- confirm.
    task_queue: list[TaskQueueEntry] = field(default_factory=list)
    # Admitted tasks not yet started by a worker, in FIFO order.
    ready_queue: deque[ReadyQueueEntry] = field(default_factory=deque)
    # Resource label -> amount currently in use (presumably by admitted tasks).
    resources_in_use: dict[str, int | float] = field(default_factory=dict)
    # Record lookup tables, presumably keyed by pipeline_id / task_id.
    pipeline_records: dict[int, PipelineRecord] = field(default_factory=dict)
    task_records: dict[int, TaskRecord] = field(default_factory=dict)
|
stagegate/_states.py
ADDED
|
@@ -0,0 +1,34 @@
|
|
|
1
|
+
"""Internal state definitions for scheduler records."""
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
from enum import Enum
|
|
6
|
+
|
|
7
|
+
|
|
8
|
+
class SchedulerState(Enum):
    """Internal lifecycle for scheduler shutdown progression.

    ``SchedulerRuntime.state`` starts at OPEN. The member names suggest the
    progression OPEN -> SHUTTING_DOWN -> CLOSED; the transitions themselves
    live in scheduler.py (not shown here).
    """

    OPEN = "open"
    SHUTTING_DOWN = "shutting_down"
    CLOSED = "closed"
|
|
14
|
+
|
|
15
|
+
|
|
16
|
+
class PipelineState(Enum):
    """Internal pipeline execution states.

    A record starts QUEUED. RUNNING is the only state in which pipeline
    control methods are allowed. SUCCEEDED, FAILED and CANCELLED are
    terminal, and ``PipelineHandle.cancel`` only succeeds from QUEUED.
    """

    QUEUED = "queued"
    RUNNING = "running"
    SUCCEEDED = "succeeded"
    FAILED = "failed"
    CANCELLED = "cancelled"
|
|
24
|
+
|
|
25
|
+
|
|
26
|
+
class TaskState(Enum):
    """Internal task execution states.

    A record starts QUEUED. READY presumably means admitted but not yet
    picked up by a worker (cf. ``ReadyQueueEntry``). RUNNING means the task
    callable is executing. SUCCEEDED, FAILED and CANCELLED are terminal.
    """

    QUEUED = "queued"
    READY = "ready"
    RUNNING = "running"
    SUCCEEDED = "succeeded"
    FAILED = "failed"
    CANCELLED = "cancelled"
|
stagegate/_wait_utils.py
ADDED
|
@@ -0,0 +1,97 @@
|
|
|
1
|
+
"""Internal helpers for owner-scoped wait APIs."""
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
import time
|
|
6
|
+
from collections.abc import Callable, Iterable
|
|
7
|
+
from typing import TypeVar
|
|
8
|
+
|
|
9
|
+
from ._states import PipelineState, TaskState
|
|
10
|
+
from .handles import PipelineHandle, TaskHandle
|
|
11
|
+
from .wait import ALL_COMPLETED, FIRST_COMPLETED, WAIT_CONDITIONS
|
|
12
|
+
|
|
13
|
+
# Constrained type variable: the wait helpers below accept either task handles
# or pipeline handles; a type checker binds it to one of the two per call.
HandleT = TypeVar("HandleT", TaskHandle, PipelineHandle)
|
|
14
|
+
|
|
15
|
+
|
|
16
|
+
def validate_wait_request(
    handles: Iterable[HandleT],
    *,
    expected_type: type[HandleT],
    owner_check: Callable[[HandleT], bool],
    timeout: float | None,
    return_when: str,
) -> set[HandleT]:
    """Normalize and validate public wait() inputs.

    Args:
        handles: Handles to wait on. The iterable is consumed exactly once, so
            one-shot iterators are fine; duplicates collapse in the result.
        expected_type: Class that every handle must be an instance of.
        owner_check: Predicate returning True when a handle belongs to the
            waiting owner.
        timeout: ``None`` to wait indefinitely, otherwise a non-negative
            number of seconds.
        return_when: One of the values in ``WAIT_CONDITIONS``.

    Returns:
        The de-duplicated set of handles.

    Raises:
        ValueError: If ``timeout`` is negative or NaN, ``return_when`` is not
            a known condition, ``handles`` is empty, or a handle fails
            ``owner_check``.
        TypeError: If a handle is not an instance of ``expected_type``.
    """

    # ``not timeout >= 0`` rejects negative values *and* NaN (every comparison
    # with NaN is False). A NaN timeout would otherwise pass validation and
    # flow into the deadline arithmetic and Condition.wait().
    if timeout is not None and not timeout >= 0:
        raise ValueError("timeout must be None or a non-negative number")
    if return_when not in WAIT_CONDITIONS:
        raise ValueError("invalid return_when")

    candidates = tuple(handles)
    if not candidates:
        raise ValueError("handles must not be empty")

    for handle in candidates:
        # Type check first so owner_check may rely on the handle's attributes.
        if not isinstance(handle, expected_type):
            raise TypeError("wrong handle type")
        if not owner_check(handle):
            raise ValueError("handle does not belong to this owner")

    return set(candidates)
|
|
43
|
+
|
|
44
|
+
|
|
45
|
+
def split_done_pending(handles: set[HandleT]) -> tuple[set[HandleT], set[HandleT]]:
    """Return ``(done, pending)`` according to each handle's ``done()``.

    ``done()`` is queried exactly once per handle, so the two sets are
    disjoint and together contain every input handle.
    """

    finished: set[HandleT] = set()
    unfinished: set[HandleT] = set()
    for candidate in handles:
        bucket = finished if candidate.done() else unfinished
        bucket.add(candidate)
    return finished, unfinished
|
|
56
|
+
|
|
57
|
+
|
|
58
|
+
def should_return(
    *,
    done: set[HandleT],
    pending: set[HandleT],
    return_when: str,
) -> bool:
    """Decide whether a wait() call has met its ``return_when`` condition.

    The rules are:

    - ``ALL_COMPLETED`` requires nothing pending.
    - ``FIRST_COMPLETED`` requires at least one finished handle.
    - Any other value (``FIRST_EXCEPTION`` after validation) is satisfied
      when a finished handle has FAILED, or otherwise when nothing is
      pending.
    """

    if return_when == ALL_COMPLETED:
        return not pending
    if return_when == FIRST_COMPLETED:
        return len(done) > 0
    saw_failure = any(_is_failed(candidate) for candidate in done)
    return saw_failure or not pending
|
|
73
|
+
|
|
74
|
+
|
|
75
|
+
def monotonic_deadline(timeout: float | None) -> float | None:
    """Turn a relative ``timeout`` (seconds) into an absolute monotonic deadline.

    ``None`` (wait forever) maps to ``None``; otherwise the result is on the
    ``time.monotonic()`` clock.
    """

    return None if timeout is None else time.monotonic() + timeout
|
|
81
|
+
|
|
82
|
+
|
|
83
|
+
def remaining_timeout(deadline: float | None) -> float | None:
    """Return the seconds left until ``deadline``, clamped at ``0.0``.

    ``None`` means there is no deadline and is passed through unchanged.
    """

    if deadline is None:
        return None
    left = deadline - time.monotonic()
    return 0.0 if left <= 0 else left
|
|
92
|
+
|
|
93
|
+
|
|
94
|
+
def _is_failed(handle: HandleT) -> bool:
    """Return True when ``handle``'s record is in its kind's FAILED state.

    The record state is read directly, without taking the scheduler lock.
    The visible caller path (``Pipeline.wait`` -> ``should_return``) already
    holds it.
    """

    failed_state = (
        TaskState.FAILED if isinstance(handle, TaskHandle) else PipelineState.FAILED
    )
    return handle._record.state is failed_state
|
stagegate/exceptions.py
ADDED
|
@@ -0,0 +1,13 @@
|
|
|
1
|
+
"""Public exceptions for stagegate."""
|
|
2
|
+
|
|
3
|
+
|
|
4
|
+
class CancelledError(Exception):
    """Raised when result() or exception() is requested from a cancelled handle.

    This derives directly from ``Exception``. It is unrelated to
    ``concurrent.futures.CancelledError`` and ``asyncio.CancelledError``, so
    catch ``stagegate.CancelledError`` explicitly.
    """
|
|
6
|
+
|
|
7
|
+
|
|
8
|
+
class UnknownResourceError(ValueError):
    """Raised when a task requests a resource label unknown to the scheduler.

    Subclasses ``ValueError``, so generic ``except ValueError`` handlers also
    catch it.
    """
|
|
10
|
+
|
|
11
|
+
|
|
12
|
+
class UnschedulableTaskError(ValueError):
    """Raised when a single task can never fit within scheduler capacity.

    Subclasses ``ValueError``, so generic ``except ValueError`` handlers also
    catch it.
    """
|
stagegate/handles.py
ADDED
|
@@ -0,0 +1,216 @@
|
|
|
1
|
+
"""Public handle skeletons."""
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
import time
|
|
6
|
+
from typing import TYPE_CHECKING, Any
|
|
7
|
+
|
|
8
|
+
from ._records import TerminalPipelineState, TerminalTaskState
|
|
9
|
+
from ._states import PipelineState, TaskState
|
|
10
|
+
from .exceptions import CancelledError
|
|
11
|
+
|
|
12
|
+
if TYPE_CHECKING:
|
|
13
|
+
from ._records import PipelineRecord, TaskRecord
|
|
14
|
+
|
|
15
|
+
|
|
16
|
+
def _validate_timeout(timeout: float | None) -> float | None:
|
|
17
|
+
if timeout is None:
|
|
18
|
+
return None
|
|
19
|
+
if timeout < 0:
|
|
20
|
+
raise ValueError("timeout must be None or a non-negative number")
|
|
21
|
+
return timeout
|
|
22
|
+
|
|
23
|
+
|
|
24
|
+
def _remaining_timeout(deadline: float) -> float:
|
|
25
|
+
remaining = deadline - time.monotonic()
|
|
26
|
+
if remaining <= 0:
|
|
27
|
+
raise TimeoutError
|
|
28
|
+
return remaining
|
|
29
|
+
|
|
30
|
+
|
|
31
|
+
class TaskHandle:
    """Handle for a task submitted from a pipeline.

    A thin view over the scheduler-owned task record. Two handles are equal,
    and hash equally, when they refer to the same ``task_id`` on the same
    scheduler object. Every state query acquires the scheduler's
    ``_condition`` before reading the record.
    """

    __slots__ = ("_record",)

    def __init__(self, record: TaskRecord) -> None:
        self._record = record

    def __eq__(self, other: object) -> bool:
        if isinstance(other, TaskHandle):
            mine, theirs = self._record, other._record
            return (
                mine.task_id == theirs.task_id
                and mine.scheduler is theirs.scheduler
            )
        return NotImplemented

    def __hash__(self) -> int:
        record = self._record
        return hash((TaskHandle, id(record.scheduler), record.task_id))

    def __repr__(self) -> str:
        record = self._record
        return (
            f"TaskHandle(scheduler=0x{id(record.scheduler):x}, "
            f"task_id={record.task_id})"
        )

    def cancel(self) -> bool:
        """Cancel the task if it has not started yet.

        Delegates the decision to the scheduler while holding its condition
        lock and returns the scheduler's answer.
        """
        owner = self._record.scheduler
        with owner._condition:
            return owner._cancel_task_if_possible_locked(self._record)

    def done(self) -> bool:
        """Return whether the task has reached a terminal state."""
        owner = self._record.scheduler
        with owner._condition:
            return self._record.state in TerminalTaskState

    def running(self) -> bool:
        """Return whether the task is in the RUNNING state."""
        owner = self._record.scheduler
        with owner._condition:
            return self._record.state is TaskState.RUNNING

    def cancelled(self) -> bool:
        """Return whether the task ended in the CANCELLED state."""
        owner = self._record.scheduler
        with owner._condition:
            return self._record.state is TaskState.CANCELLED

    def _wait_terminal_locked(self, deadline: float | None) -> TaskState:
        """Block until the task is terminal and return that state.

        Must be called with the scheduler condition held. ``_remaining_timeout``
        raises ``TimeoutError`` once ``deadline`` has passed.
        """
        record = self._record
        condition = record.scheduler._condition
        while record.state not in TerminalTaskState:
            condition.wait(
                None if deadline is None else _remaining_timeout(deadline)
            )
        return record.state

    def result(self, timeout: float | None = None) -> Any:
        """Wait for the task and return its value or raise its outcome.

        Args:
            timeout: Maximum seconds to wait; ``None`` waits indefinitely.

        Raises:
            ValueError: If ``timeout`` is negative.
            TimeoutError: If the task is not terminal within ``timeout``.
            CancelledError: If the task was cancelled.
            BaseException: The task's stored exception, if the task failed.
        """
        timeout = _validate_timeout(timeout)
        owner = self._record.scheduler
        deadline = None if timeout is None else time.monotonic() + timeout
        with owner._condition:
            final = self._wait_terminal_locked(deadline)
            if final is TaskState.SUCCEEDED:
                return self._record.result_value
            if final is TaskState.FAILED:
                assert self._record.exception is not None
                raise self._record.exception
            raise CancelledError("result() was requested from a cancelled task handle")

    def exception(self, timeout: float | None = None) -> BaseException | None:
        """Wait for the task and return its exception, or ``None`` on success.

        Args:
            timeout: Maximum seconds to wait; ``None`` waits indefinitely.

        Raises:
            ValueError: If ``timeout`` is negative.
            TimeoutError: If the task is not terminal within ``timeout``.
            CancelledError: If the task was cancelled.
        """
        timeout = _validate_timeout(timeout)
        owner = self._record.scheduler
        deadline = None if timeout is None else time.monotonic() + timeout
        with owner._condition:
            final = self._wait_terminal_locked(deadline)
            if final is TaskState.SUCCEEDED:
                return None
            if final is TaskState.FAILED:
                assert self._record.exception is not None
                return self._record.exception
            raise CancelledError(
                "exception() was requested from a cancelled task handle"
            )
|
|
119
|
+
|
|
120
|
+
|
|
121
|
+
class PipelineHandle:
    """Handle for a pipeline submitted to a scheduler.

    A thin view over the scheduler-owned pipeline record. Handles are equal,
    and hash equally, when they name the same ``pipeline_id`` on the same
    scheduler object. Every state query acquires the scheduler's
    ``_condition`` first.
    """

    __slots__ = ("_record",)

    def __init__(self, record: PipelineRecord) -> None:
        self._record = record

    def __eq__(self, other: object) -> bool:
        if isinstance(other, PipelineHandle):
            mine, theirs = self._record, other._record
            return (
                mine.pipeline_id == theirs.pipeline_id
                and mine.scheduler is theirs.scheduler
            )
        return NotImplemented

    def __hash__(self) -> int:
        record = self._record
        return hash((PipelineHandle, id(record.scheduler), record.pipeline_id))

    def __repr__(self) -> str:
        record = self._record
        return (
            f"PipelineHandle(scheduler=0x{id(record.scheduler):x}, "
            f"pipeline_id={record.pipeline_id})"
        )

    def cancel(self) -> bool:
        """Cancel the pipeline if it has not started yet.

        A QUEUED pipeline is moved to CANCELLED and the scheduler is notified,
        returning True. In any other state nothing changes and False is
        returned.
        """
        record = self._record
        owner = record.scheduler
        with owner._condition:
            if record.state is not PipelineState.QUEUED:
                return False
            record.state = PipelineState.CANCELLED
            owner._notify_state_change_locked()
            return True

    def done(self) -> bool:
        """Return whether the pipeline has reached a terminal state."""
        owner = self._record.scheduler
        with owner._condition:
            return self._record.state in TerminalPipelineState

    def running(self) -> bool:
        """Return whether the pipeline is in the RUNNING state."""
        owner = self._record.scheduler
        with owner._condition:
            return self._record.state is PipelineState.RUNNING

    def cancelled(self) -> bool:
        """Return whether the pipeline ended in the CANCELLED state."""
        owner = self._record.scheduler
        with owner._condition:
            return self._record.state is PipelineState.CANCELLED

    def _wait_terminal_locked(self, deadline: float | None) -> PipelineState:
        """Block until the pipeline is terminal and return that state.

        Must be called with the scheduler condition held. ``_remaining_timeout``
        raises ``TimeoutError`` once ``deadline`` has passed.
        """
        record = self._record
        condition = record.scheduler._condition
        while record.state not in TerminalPipelineState:
            condition.wait(
                None if deadline is None else _remaining_timeout(deadline)
            )
        return record.state

    def result(self, timeout: float | None = None) -> Any:
        """Wait for the pipeline and return its value or raise its outcome.

        Args:
            timeout: Maximum seconds to wait; ``None`` waits indefinitely.

        Raises:
            ValueError: If ``timeout`` is negative.
            TimeoutError: If the pipeline is not terminal within ``timeout``.
            CancelledError: If the pipeline was cancelled.
            BaseException: The pipeline's stored exception, if it failed.
        """
        timeout = _validate_timeout(timeout)
        owner = self._record.scheduler
        deadline = None if timeout is None else time.monotonic() + timeout
        with owner._condition:
            final = self._wait_terminal_locked(deadline)
            if final is PipelineState.SUCCEEDED:
                return self._record.result_value
            if final is PipelineState.FAILED:
                assert self._record.exception is not None
                raise self._record.exception
            raise CancelledError(
                "result() was requested from a cancelled pipeline handle"
            )

    def exception(self, timeout: float | None = None) -> BaseException | None:
        """Wait for the pipeline and return its exception, or ``None`` on success.

        Args:
            timeout: Maximum seconds to wait; ``None`` waits indefinitely.

        Raises:
            ValueError: If ``timeout`` is negative.
            TimeoutError: If the pipeline is not terminal within ``timeout``.
            CancelledError: If the pipeline was cancelled.
        """
        timeout = _validate_timeout(timeout)
        owner = self._record.scheduler
        deadline = None if timeout is None else time.monotonic() + timeout
        with owner._condition:
            final = self._wait_terminal_locked(deadline)
            if final is PipelineState.SUCCEEDED:
                return None
            if final is PipelineState.FAILED:
                assert self._record.exception is not None
                return self._record.exception
            raise CancelledError(
                "exception() was requested from a cancelled pipeline handle"
            )
|
stagegate/pipeline.py
ADDED
|
@@ -0,0 +1,122 @@
|
|
|
1
|
+
"""Pipeline base class and task-builder skeleton."""
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
from dataclasses import dataclass, field
|
|
6
|
+
import threading
|
|
7
|
+
from typing import TYPE_CHECKING, Any
|
|
8
|
+
|
|
9
|
+
from .handles import TaskHandle
|
|
10
|
+
from ._states import PipelineState
|
|
11
|
+
from ._wait_utils import (
|
|
12
|
+
monotonic_deadline,
|
|
13
|
+
remaining_timeout,
|
|
14
|
+
should_return,
|
|
15
|
+
split_done_pending,
|
|
16
|
+
validate_wait_request,
|
|
17
|
+
)
|
|
18
|
+
from .wait import ALL_COMPLETED
|
|
19
|
+
|
|
20
|
+
if TYPE_CHECKING:
|
|
21
|
+
from collections.abc import Callable, Iterable
|
|
22
|
+
|
|
23
|
+
from ._records import PipelineRecord
|
|
24
|
+
from .scheduler import Scheduler
|
|
25
|
+
|
|
26
|
+
|
|
27
|
+
@dataclass(frozen=True, slots=True)
|
|
28
|
+
class TaskBuilder:
|
|
29
|
+
"""Factory object returned by ``Pipeline.task(...)``."""
|
|
30
|
+
|
|
31
|
+
pipeline: Pipeline
|
|
32
|
+
fn: Callable[..., Any]
|
|
33
|
+
resources: dict[str, int | float]
|
|
34
|
+
args: tuple[Any, ...] = ()
|
|
35
|
+
kwargs: dict[str, Any] = field(default_factory=dict)
|
|
36
|
+
name: str | None = None
|
|
37
|
+
|
|
38
|
+
def run(self) -> TaskHandle:
|
|
39
|
+
"""Submit the task to the owning scheduler and return its handle."""
|
|
40
|
+
pipeline = self.pipeline
|
|
41
|
+
scheduler, _ = pipeline._require_control_context()
|
|
42
|
+
with scheduler._condition:
|
|
43
|
+
return scheduler._submit_task_builder_locked(self)
|
|
44
|
+
|
|
45
|
+
|
|
46
|
+
class Pipeline:
    """Base class for user-defined pipelines.

    Subclasses override :meth:`run`. The ``_stagegate_*`` attributes are read
    by ``_require_control_context``; presumably the scheduler fills them in
    on submission (not visible here). The control operations are only valid
    while the pipeline is RUNNING and only on its coordinator thread:

    - ``TaskBuilder.run()``
    - ``stage_forward()``
    - ``wait()``
    """

    _stagegate_record: PipelineRecord | None = None
    _stagegate_scheduler: Scheduler | None = None
    _stagegate_submitted: bool = False

    def run(self) -> Any:
        """Execute pipeline logic on a scheduler-owned coordinator thread.

        Subclasses must override this; the base implementation raises
        ``NotImplementedError``.
        """
        raise NotImplementedError

    def _require_control_context(self) -> tuple[Scheduler, PipelineRecord]:
        """Return ``(scheduler, record)`` if pipeline control is allowed here.

        Raises:
            RuntimeError: If the pipeline is not attached to a scheduler, is
                not RUNNING, or the caller is not its coordinator thread.
        """
        record = getattr(self, "_stagegate_record", None)
        scheduler = getattr(self, "_stagegate_scheduler", None)
        if record is None or scheduler is None:
            raise RuntimeError("pipeline control requires a running pipeline")
        if record.state is not PipelineState.RUNNING:
            raise RuntimeError("pipeline control requires a running pipeline")
        if record.coordinator_thread_ident != threading.get_ident():
            raise RuntimeError(
                "pipeline control is allowed only on the coordinator thread"
            )
        return scheduler, record

    def task(
        self,
        fn: Callable[..., Any],
        *,
        resources: dict[str, int | float],
        args: tuple[Any, ...] = (),
        kwargs: dict[str, Any] | None = None,
        name: str | None = None,
    ) -> TaskBuilder:
        """Create a task builder for later submission via ``.run()``.

        No control-context check happens here; it is performed when the
        returned builder's ``run()`` is called.

        Args:
            fn: Callable to execute for the task.
            resources: Resource label -> requested amount.
            args: Positional arguments for ``fn`` (any iterable is accepted
                and frozen into a tuple).
            kwargs: Keyword arguments for ``fn``; ``None`` means none.
            name: Optional task name.
        """
        return TaskBuilder(
            pipeline=self,
            fn=fn,
            # Snapshot every container argument so later mutation by the
            # caller cannot alter a task that has already been described.
            resources=dict(resources),
            args=tuple(args),
            kwargs={} if kwargs is None else dict(kwargs),
            name=name,
        )

    def stage_forward(self) -> None:
        """Advance the pipeline stage used by future task submissions.

        Raises:
            RuntimeError: If called outside the running pipeline's
                coordinator thread.
        """
        scheduler, record = self._require_control_context()
        with scheduler._condition:
            record.stage_index += 1

    def wait(
        self,
        handles: Iterable[TaskHandle],
        timeout: float | None = None,
        return_when: str = ALL_COMPLETED,
    ) -> tuple[set[TaskHandle], set[TaskHandle]]:
        """Wait for task handles created by this pipeline.

        Args:
            handles: Non-empty iterable of handles for tasks submitted by this
                pipeline.
            timeout: Maximum seconds to wait; ``None`` waits indefinitely.
            return_when: A wait condition such as ``ALL_COMPLETED``,
                ``FIRST_COMPLETED`` or ``FIRST_EXCEPTION``.

        Returns:
            A ``(done, pending)`` pair of sets. When the timeout expires, the
            current partition is returned rather than raising.

        Raises:
            RuntimeError: If called outside the running pipeline's
                coordinator thread.
            ValueError: If the arguments are invalid (see
                ``validate_wait_request``).
            TypeError: If a handle has the wrong type (see
                ``validate_wait_request``).
        """
        scheduler, _ = self._require_control_context()
        # validate_wait_request also checks return_when against WAIT_CONDITIONS.
        normalized = validate_wait_request(
            handles,
            expected_type=TaskHandle,
            owner_check=lambda handle: handle._record.pipeline_record.pipeline is self,
            timeout=timeout,
            return_when=return_when,
        )
        deadline = monotonic_deadline(timeout)

        with scheduler._condition:
            while True:
                done, pending = split_done_pending(normalized)
                if should_return(done=done, pending=pending, return_when=return_when):
                    return done, pending

                wait_timeout = remaining_timeout(deadline)
                # Zero means the deadline has passed: report the partial
                # result instead of raising.
                if wait_timeout == 0.0:
                    return done, pending
                scheduler._condition.wait(wait_timeout)
|