stagegate 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
stagegate/__init__.py ADDED
@@ -0,0 +1,20 @@
1
+ """Public package namespace for stage-aware local pipeline execution."""
2
+
3
+ from .exceptions import CancelledError, UnknownResourceError, UnschedulableTaskError
4
+ from .handles import PipelineHandle, TaskHandle
5
+ from .pipeline import Pipeline
6
+ from .scheduler import Scheduler
7
+ from .wait import ALL_COMPLETED, FIRST_COMPLETED, FIRST_EXCEPTION
8
+
9
# Names re-exported as the package's public API; each one is imported above.
__all__ = [
    # Core classes.
    "Scheduler",
    "Pipeline",
    # Handle types.
    "TaskHandle",
    "PipelineHandle",
    # Constants imported from ``.wait``.
    "FIRST_COMPLETED",
    "FIRST_EXCEPTION",
    "ALL_COMPLETED",
    # Exception types imported from ``.exceptions``.
    "CancelledError",
    "UnknownResourceError",
    "UnschedulableTaskError",
]
stagegate/_records.py ADDED
@@ -0,0 +1,150 @@
1
+ """Internal scheduler record and queue-entry definitions."""
2
+
3
+ from __future__ import annotations
4
+
5
+ from collections import deque
6
+ from collections.abc import Callable
7
+ from dataclasses import dataclass, field
8
+ from typing import TYPE_CHECKING, Any
9
+
10
+ from ._states import PipelineState, SchedulerState, TaskState
11
+
12
+ if TYPE_CHECKING:
13
+ from .pipeline import Pipeline
14
+ from .scheduler import Scheduler
15
+
16
+
17
# Task states treated as terminal; used for membership tests such as
# ``TaskRecord.is_terminal``.
TerminalTaskState = frozenset(
    (TaskState.SUCCEEDED, TaskState.FAILED, TaskState.CANCELLED)
)

# Pipeline states treated as terminal; used for membership tests such as
# ``PipelineRecord.is_terminal``.
TerminalPipelineState = frozenset(
    (PipelineState.SUCCEEDED, PipelineState.FAILED, PipelineState.CANCELLED)
)
32
+
33
+
34
+ @dataclass(order=True, frozen=True, slots=True)
35
+ class PipelineQueueEntry:
36
+ """FIFO queue entry for pipeline admission."""
37
+
38
+ enqueue_seq: int
39
+ record: PipelineRecord = field(compare=False)
40
+
41
+
42
+ @dataclass(order=True, frozen=True, slots=True)
43
+ class TaskPriorityKey:
44
+ """Deterministic ordering key for queued tasks."""
45
+
46
+ neg_stage: int
47
+ pipeline_enqueue_seq: int
48
+ pipeline_local_task_seq: int
49
+ global_task_submit_seq: int
50
+
51
+
52
+ @dataclass(order=True, frozen=True, slots=True)
53
+ class TaskQueueEntry:
54
+ """Priority-queue entry for task admission control."""
55
+
56
+ priority: TaskPriorityKey
57
+ record: TaskRecord = field(compare=False)
58
+
59
+ @classmethod
60
+ def from_record(cls, record: TaskRecord) -> TaskQueueEntry:
61
+ """Build a queue entry from the record's canonical priority key."""
62
+
63
+ return cls(priority=record.priority_key(), record=record)
64
+
65
+
66
+ @dataclass(order=True, frozen=True, slots=True)
67
+ class ReadyQueueEntry:
68
+ """FIFO queue entry for tasks admitted but not yet started by a worker."""
69
+
70
+ ready_seq: int
71
+ record: TaskRecord = field(compare=False)
72
+
73
+
74
@dataclass(slots=True)
class PipelineRecord:
    """Mutable bookkeeping record describing a single pipeline instance."""

    # Identity and ownership.
    scheduler: Scheduler
    pipeline: Pipeline
    pipeline_id: int
    enqueue_seq: int

    # Lifecycle; a new record starts out QUEUED.
    state: PipelineState = PipelineState.QUEUED
    # Incremented by ``Pipeline.stage_forward``.
    stage_index: int = 0
    next_task_seq: int = 0

    # Outcome surfaced through ``PipelineHandle.result`` / ``exception``.
    result_value: Any = None
    exception: BaseException | None = None

    # Compared with ``threading.get_ident()`` by the pipeline control check.
    coordinator_thread_ident: int | None = None
    task_records: list[TaskRecord] = field(default_factory=list)

    def is_terminal(self) -> bool:
        """Report whether ``state`` is a member of ``TerminalPipelineState``."""

        current = self.state
        return current in TerminalPipelineState
94
+
95
+
96
@dataclass(slots=True)
class TaskRecord:
    """Mutable bookkeeping record describing a single scheduled task."""

    # Identity and ownership.
    scheduler: Scheduler
    pipeline_record: PipelineRecord
    task_id: int

    # What to run and what it needs.
    fn: Callable[..., Any]
    resources_required: dict[str, int | float]
    args: tuple[Any, ...] = ()
    kwargs: dict[str, Any] = field(default_factory=dict)
    name: str | None = None

    # Inputs to ``priority_key``.
    stage_snapshot: int = 0
    pipeline_enqueue_seq: int = 0
    pipeline_local_task_seq: int = 0
    global_task_submit_seq: int = 0

    # Lifecycle; a new record starts out QUEUED.
    state: TaskState = TaskState.QUEUED
    # Outcome surfaced through ``TaskHandle.result`` / ``exception``.
    result_value: Any = None
    exception: BaseException | None = None
    ready_seq: int | None = None
    worker_thread_ident: int | None = None

    def priority_key(self) -> TaskPriorityKey:
        """Return the ``TaskPriorityKey`` derived from this record's fields.

        The stage is negated, so a larger ``stage_snapshot`` produces a key
        that compares as smaller.
        """

        negated_stage = -self.stage_snapshot
        return TaskPriorityKey(
            neg_stage=negated_stage,
            pipeline_enqueue_seq=self.pipeline_enqueue_seq,
            pipeline_local_task_seq=self.pipeline_local_task_seq,
            global_task_submit_seq=self.global_task_submit_seq,
        )

    def is_terminal(self) -> bool:
        """Report whether ``state`` is a member of ``TerminalTaskState``."""

        current = self.state
        return current in TerminalTaskState
132
+
133
+
134
@dataclass(slots=True)
class SchedulerRuntime:
    """Mutable container for scheduler-wide counters, queues and indexes."""

    # Scheduler lifecycle; starts OPEN.
    state: SchedulerState = SchedulerState.OPEN

    # Monotonic counters, all starting at zero.
    next_pipeline_id: int = 0
    next_pipeline_enqueue_seq: int = 0
    next_task_id: int = 0
    next_global_task_submit_seq: int = 0
    next_ready_seq: int = 0
    admitted_task_count: int = 0

    # Queues. NOTE(review): ``task_queue`` is a plain list -- presumably kept
    # as a heap by the scheduler; confirm there.
    pipeline_queue: deque[PipelineQueueEntry] = field(default_factory=deque)
    task_queue: list[TaskQueueEntry] = field(default_factory=list)
    ready_queue: deque[ReadyQueueEntry] = field(default_factory=deque)

    # Resource accounting keyed by resource label.
    resources_in_use: dict[str, int | float] = field(default_factory=dict)

    # Record lookup tables keyed by integer id.
    pipeline_records: dict[int, PipelineRecord] = field(default_factory=dict)
    task_records: dict[int, TaskRecord] = field(default_factory=dict)
stagegate/_states.py ADDED
@@ -0,0 +1,34 @@
1
+ """Internal state definitions for scheduler records."""
2
+
3
+ from __future__ import annotations
4
+
5
+ from enum import Enum
6
+
7
+
8
class SchedulerState(Enum):
    """Lifecycle phases a scheduler passes through on its way to shutdown."""

    # Initial state of ``SchedulerRuntime``.
    OPEN = "open"
    SHUTTING_DOWN = "shutting_down"
    CLOSED = "closed"
14
+
15
+
16
class PipelineState(Enum):
    """States a pipeline record can be in."""

    # Non-terminal states.
    QUEUED = "queued"
    RUNNING = "running"
    # Terminal outcomes.
    SUCCEEDED = "succeeded"
    FAILED = "failed"
    CANCELLED = "cancelled"
24
+
25
+
26
class TaskState(Enum):
    """States a task record can be in."""

    # Non-terminal states.
    QUEUED = "queued"
    READY = "ready"
    RUNNING = "running"
    # Terminal outcomes.
    SUCCEEDED = "succeeded"
    FAILED = "failed"
    CANCELLED = "cancelled"
stagegate/_wait_utils.py ADDED
@@ -0,0 +1,97 @@
1
+ """Internal helpers for owner-scoped wait APIs."""
2
+
3
+ from __future__ import annotations
4
+
5
+ import time
6
+ from collections.abc import Callable, Iterable
7
+ from typing import TypeVar
8
+
9
+ from ._states import PipelineState, TaskState
10
+ from .handles import PipelineHandle, TaskHandle
11
+ from .wait import ALL_COMPLETED, FIRST_COMPLETED, WAIT_CONDITIONS
12
+
13
+ HandleT = TypeVar("HandleT", TaskHandle, PipelineHandle)
14
+
15
+
16
def validate_wait_request(
    handles: Iterable[HandleT],
    *,
    expected_type: type[HandleT],
    owner_check: Callable[[HandleT], bool],
    timeout: float | None,
    return_when: str,
) -> set[HandleT]:
    """Normalize and validate public wait() inputs.

    Args:
        handles: Iterable of handles to wait on; consumed exactly once.
        expected_type: Handle class every element must be an instance of.
        owner_check: Predicate returning True when a handle belongs to the
            caller.
        timeout: ``None`` or a non-negative, non-NaN number.
        return_when: Must be a member of ``WAIT_CONDITIONS``.

    Returns:
        The handles as a set (duplicates collapse).

    Raises:
        ValueError: If ``timeout`` is negative or NaN, ``return_when`` is not
            a known condition, ``handles`` is empty, or a handle fails
            ``owner_check``.
        TypeError: If an element is not an ``expected_type`` instance.
    """

    # NaN compares False against every number, so ``timeout < 0`` alone would
    # accept it. A NaN deadline never compares as expired, which means the
    # caller's wait loop would never take its timeout exit. ``x != x`` is
    # true only for NaN.
    if timeout is not None and (timeout < 0 or timeout != timeout):
        raise ValueError("timeout must be None or a non-negative number")
    if return_when not in WAIT_CONDITIONS:
        raise ValueError("invalid return_when")

    # Materialize once so one-shot iterables can be both checked and returned.
    candidates = tuple(handles)
    if not candidates:
        raise ValueError("handles must not be empty")

    for handle in candidates:
        if not isinstance(handle, expected_type):
            raise TypeError("wrong handle type")
        if not owner_check(handle):
            raise ValueError("handle does not belong to this owner")

    return set(candidates)
43
+
44
+
45
def split_done_pending(handles: set[HandleT]) -> tuple[set[HandleT], set[HandleT]]:
    """Split ``handles`` by the result of each handle's ``done()`` call.

    Returns:
        ``(done, pending)``: handles whose ``done()`` returned true, and the
        remaining handles.
    """

    finished: set[HandleT] = set()
    unfinished: set[HandleT] = set()
    for item in handles:
        # ``done()`` is queried exactly once per handle.
        bucket = finished if item.done() else unfinished
        bucket.add(item)
    return finished, unfinished
56
+
57
+
58
def should_return(
    *,
    done: set[HandleT],
    pending: set[HandleT],
    return_when: str,
) -> bool:
    """Decide whether a wait call may return for the given partition.

    Args:
        done: Handles already terminal.
        pending: Handles not yet terminal.
        return_when: The wait condition requested by the caller.

    Returns:
        True when the condition named by ``return_when`` is satisfied.
    """

    nothing_pending = not pending
    if return_when == ALL_COMPLETED:
        return nothing_pending
    if return_when == FIRST_COMPLETED:
        return bool(done)
    # Any other value (presumably FIRST_EXCEPTION): return on the first
    # failed handle, otherwise only once nothing is pending.
    return any(_is_failed(handle) for handle in done) or nothing_pending
73
+
74
+
75
def monotonic_deadline(timeout: float | None) -> float | None:
    """Turn a relative timeout into a ``time.monotonic()`` deadline.

    Returns ``None`` unchanged when no timeout was requested.
    """

    return None if timeout is None else time.monotonic() + timeout
81
+
82
+
83
def remaining_timeout(deadline: float | None) -> float | None:
    """Return the time left until ``deadline``, clamped at ``0.0``.

    ``None`` (no deadline) is passed through unchanged. An elapsed deadline
    yields exactly ``0.0``.
    """

    if deadline is None:
        return None
    time_left = deadline - time.monotonic()
    return 0.0 if time_left <= 0 else time_left
92
+
93
+
94
def _is_failed(handle: HandleT) -> bool:
    """Report whether the handle's record is in its FAILED state.

    Task handles are compared against ``TaskState.FAILED``; anything else is
    compared against ``PipelineState.FAILED``.
    """
    if isinstance(handle, TaskHandle):
        failed_state = TaskState.FAILED
    else:
        failed_state = PipelineState.FAILED
    return handle._record.state is failed_state
stagegate/exceptions.py ADDED
@@ -0,0 +1,13 @@
1
+ """Public exceptions for stagegate."""
2
+
3
+
4
class CancelledError(Exception):
    """Signals that ``result()`` or ``exception()`` was called on a cancelled handle."""
6
+
7
+
8
class UnknownResourceError(ValueError):
    """Signals that a task asked for a resource label the scheduler does not know."""
10
+
11
+
12
class UnschedulableTaskError(ValueError):
    """Signals that one task on its own can never fit within scheduler capacity."""
stagegate/handles.py ADDED
@@ -0,0 +1,216 @@
1
+ """Public handle skeletons."""
2
+
3
+ from __future__ import annotations
4
+
5
+ import time
6
+ from typing import TYPE_CHECKING, Any
7
+
8
+ from ._records import TerminalPipelineState, TerminalTaskState
9
+ from ._states import PipelineState, TaskState
10
+ from .exceptions import CancelledError
11
+
12
+ if TYPE_CHECKING:
13
+ from ._records import PipelineRecord, TaskRecord
14
+
15
+
16
def _validate_timeout(timeout: float | None) -> float | None:
    """Check a public ``timeout`` argument and return it unchanged.

    Args:
        timeout: ``None`` for "wait indefinitely", or a non-negative number.

    Returns:
        ``timeout`` itself when it is valid.

    Raises:
        ValueError: If ``timeout`` is negative or NaN.
    """
    if timeout is None:
        return None
    # NaN compares False against every number, so ``timeout < 0`` alone would
    # accept it. A NaN deadline never satisfies ``remaining <= 0`` in
    # ``_remaining_timeout``, so the wait would never raise TimeoutError.
    # ``x != x`` is true only for NaN.
    if timeout < 0 or timeout != timeout:
        raise ValueError("timeout must be None or a non-negative number")
    return timeout
22
+
23
+
24
def _remaining_timeout(deadline: float) -> float:
    """Return the time left before ``deadline`` on the monotonic clock.

    Raises:
        TimeoutError: If the deadline has been reached or passed.
    """
    time_left = deadline - time.monotonic()
    expired = time_left <= 0
    if expired:
        raise TimeoutError
    return time_left
29
+
30
+
31
class TaskHandle:
    """Caller-facing handle wrapping one task record.

    Query methods acquire the owning scheduler's ``_condition`` before
    reading the record's state. Two handles are equal when they refer to the
    same ``task_id`` on the same scheduler object.
    """

    __slots__ = ("_record",)

    def __init__(self, record: TaskRecord) -> None:
        self._record = record

    def __eq__(self, other: object) -> bool:
        if isinstance(other, TaskHandle):
            mine, theirs = self._record, other._record
            return (
                mine.task_id == theirs.task_id
                and mine.scheduler is theirs.scheduler
            )
        return NotImplemented

    def __hash__(self) -> int:
        rec = self._record
        # Mirrors __eq__: scheduler identity plus task id, tagged by class.
        return hash((TaskHandle, id(rec.scheduler), rec.task_id))

    def __repr__(self) -> str:
        rec = self._record
        return f"TaskHandle(scheduler=0x{id(rec.scheduler):x}, task_id={rec.task_id})"

    def _block_until_terminal_locked(self, deadline: float | None) -> None:
        """Wait on the scheduler condition until the task state is terminal.

        The caller must already hold ``scheduler._condition``. With a
        deadline, ``_remaining_timeout`` raises ``TimeoutError`` once it has
        passed.
        """
        owner = self._record.scheduler
        while self._record.state not in TerminalTaskState:
            if deadline is None:
                owner._condition.wait(None)
            else:
                owner._condition.wait(_remaining_timeout(deadline))

    def cancel(self) -> bool:
        """Ask the scheduler to cancel this task; return its verdict."""
        owner = self._record.scheduler
        with owner._condition:
            outcome = owner._cancel_task_if_possible_locked(self._record)
        return outcome

    def done(self) -> bool:
        """Report whether the task has reached a terminal state."""
        with self._record.scheduler._condition:
            return self._record.state in TerminalTaskState

    def running(self) -> bool:
        """Report whether the task is currently in the RUNNING state."""
        with self._record.scheduler._condition:
            return self._record.state is TaskState.RUNNING

    def cancelled(self) -> bool:
        """Report whether the task finished in the CANCELLED state."""
        with self._record.scheduler._condition:
            return self._record.state is TaskState.CANCELLED

    def result(self, timeout: float | None = None) -> Any:
        """Block until the task is terminal, then return or raise its outcome.

        Raises:
            ValueError: If ``timeout`` is rejected by ``_validate_timeout``.
            TimeoutError: If the task is not terminal before the timeout.
            CancelledError: If the task was cancelled.
            BaseException: The stored exception, when the task failed.
        """
        limit = _validate_timeout(timeout)
        owner = self._record.scheduler
        deadline = None if limit is None else time.monotonic() + limit
        with owner._condition:
            self._block_until_terminal_locked(deadline)

            final_state = self._record.state
            if final_state is TaskState.SUCCEEDED:
                return self._record.result_value
            if final_state is TaskState.FAILED:
                assert self._record.exception is not None
                raise self._record.exception
            raise CancelledError("result() was requested from a cancelled task handle")

    def exception(self, timeout: float | None = None) -> BaseException | None:
        """Block until the task is terminal, then return its exception or None.

        Raises:
            ValueError: If ``timeout`` is rejected by ``_validate_timeout``.
            TimeoutError: If the task is not terminal before the timeout.
            CancelledError: If the task was cancelled.
        """
        limit = _validate_timeout(timeout)
        owner = self._record.scheduler
        deadline = None if limit is None else time.monotonic() + limit
        with owner._condition:
            self._block_until_terminal_locked(deadline)

            final_state = self._record.state
            if final_state is TaskState.SUCCEEDED:
                return None
            if final_state is TaskState.FAILED:
                assert self._record.exception is not None
                return self._record.exception
            raise CancelledError(
                "exception() was requested from a cancelled task handle"
            )
119
+
120
+
121
class PipelineHandle:
    """Caller-facing handle wrapping one pipeline record.

    Query methods acquire the owning scheduler's ``_condition`` before
    reading the record's state. Two handles are equal when they refer to the
    same ``pipeline_id`` on the same scheduler object.
    """

    __slots__ = ("_record",)

    def __init__(self, record: PipelineRecord) -> None:
        self._record = record

    def __eq__(self, other: object) -> bool:
        if isinstance(other, PipelineHandle):
            mine, theirs = self._record, other._record
            return (
                mine.pipeline_id == theirs.pipeline_id
                and mine.scheduler is theirs.scheduler
            )
        return NotImplemented

    def __hash__(self) -> int:
        rec = self._record
        # Mirrors __eq__: scheduler identity plus pipeline id, tagged by class.
        return hash((PipelineHandle, id(rec.scheduler), rec.pipeline_id))

    def __repr__(self) -> str:
        rec = self._record
        return (
            f"PipelineHandle(scheduler=0x{id(rec.scheduler):x}, "
            f"pipeline_id={rec.pipeline_id})"
        )

    def _block_until_terminal_locked(self, deadline: float | None) -> None:
        """Wait on the scheduler condition until the pipeline is terminal.

        The caller must already hold ``scheduler._condition``. With a
        deadline, ``_remaining_timeout`` raises ``TimeoutError`` once it has
        passed.
        """
        owner = self._record.scheduler
        while self._record.state not in TerminalPipelineState:
            if deadline is None:
                owner._condition.wait(None)
            else:
                owner._condition.wait(_remaining_timeout(deadline))

    def cancel(self) -> bool:
        """Cancel the pipeline if it is still QUEUED.

        Returns:
            True when the state was switched to CANCELLED (the scheduler is
            notified of the change), False when the pipeline was in any other
            state.
        """
        owner = self._record.scheduler
        with owner._condition:
            if self._record.state is not PipelineState.QUEUED:
                return False
            self._record.state = PipelineState.CANCELLED
            owner._notify_state_change_locked()
            return True

    def done(self) -> bool:
        """Report whether the pipeline has reached a terminal state."""
        with self._record.scheduler._condition:
            return self._record.state in TerminalPipelineState

    def running(self) -> bool:
        """Report whether the pipeline is currently in the RUNNING state."""
        with self._record.scheduler._condition:
            return self._record.state is PipelineState.RUNNING

    def cancelled(self) -> bool:
        """Report whether the pipeline finished in the CANCELLED state."""
        with self._record.scheduler._condition:
            return self._record.state is PipelineState.CANCELLED

    def result(self, timeout: float | None = None) -> Any:
        """Block until the pipeline is terminal, then return or raise its outcome.

        Raises:
            ValueError: If ``timeout`` is rejected by ``_validate_timeout``.
            TimeoutError: If the pipeline is not terminal before the timeout.
            CancelledError: If the pipeline was cancelled.
            BaseException: The stored exception, when the pipeline failed.
        """
        limit = _validate_timeout(timeout)
        owner = self._record.scheduler
        deadline = None if limit is None else time.monotonic() + limit
        with owner._condition:
            self._block_until_terminal_locked(deadline)

            final_state = self._record.state
            if final_state is PipelineState.SUCCEEDED:
                return self._record.result_value
            if final_state is PipelineState.FAILED:
                assert self._record.exception is not None
                raise self._record.exception
            raise CancelledError(
                "result() was requested from a cancelled pipeline handle"
            )

    def exception(self, timeout: float | None = None) -> BaseException | None:
        """Block until the pipeline is terminal, then return its exception or None.

        Raises:
            ValueError: If ``timeout`` is rejected by ``_validate_timeout``.
            TimeoutError: If the pipeline is not terminal before the timeout.
            CancelledError: If the pipeline was cancelled.
        """
        limit = _validate_timeout(timeout)
        owner = self._record.scheduler
        deadline = None if limit is None else time.monotonic() + limit
        with owner._condition:
            self._block_until_terminal_locked(deadline)

            final_state = self._record.state
            if final_state is PipelineState.SUCCEEDED:
                return None
            if final_state is PipelineState.FAILED:
                assert self._record.exception is not None
                return self._record.exception
            raise CancelledError(
                "exception() was requested from a cancelled pipeline handle"
            )
stagegate/pipeline.py ADDED
@@ -0,0 +1,122 @@
1
+ """Pipeline base class and task-builder skeleton."""
2
+
3
+ from __future__ import annotations
4
+
5
+ from dataclasses import dataclass, field
6
+ import threading
7
+ from typing import TYPE_CHECKING, Any
8
+
9
+ from .handles import TaskHandle
10
+ from ._states import PipelineState
11
+ from ._wait_utils import (
12
+ monotonic_deadline,
13
+ remaining_timeout,
14
+ should_return,
15
+ split_done_pending,
16
+ validate_wait_request,
17
+ )
18
+ from .wait import ALL_COMPLETED
19
+
20
+ if TYPE_CHECKING:
21
+ from collections.abc import Callable, Iterable
22
+
23
+ from ._records import PipelineRecord
24
+ from .scheduler import Scheduler
25
+
26
+
27
+ @dataclass(frozen=True, slots=True)
28
+ class TaskBuilder:
29
+ """Factory object returned by ``Pipeline.task(...)``."""
30
+
31
+ pipeline: Pipeline
32
+ fn: Callable[..., Any]
33
+ resources: dict[str, int | float]
34
+ args: tuple[Any, ...] = ()
35
+ kwargs: dict[str, Any] = field(default_factory=dict)
36
+ name: str | None = None
37
+
38
+ def run(self) -> TaskHandle:
39
+ """Submit the task to the owning scheduler and return its handle."""
40
+ pipeline = self.pipeline
41
+ scheduler, _ = pipeline._require_control_context()
42
+ with scheduler._condition:
43
+ return scheduler._submit_task_builder_locked(self)
44
+
45
+
46
class Pipeline:
    """Base class that user pipelines subclass, overriding ``run()``.

    ``stage_forward()``, ``wait()`` and ``TaskBuilder.run()`` all go through
    ``_require_control_context`` and so only work while the pipeline is
    RUNNING, and only on its coordinator thread.
    """

    # Class-level defaults. NOTE(review): presumably set per instance by the
    # scheduler on submission -- confirm in the scheduler module.
    _stagegate_record: PipelineRecord | None = None
    _stagegate_scheduler: Scheduler | None = None
    _stagegate_submitted: bool = False

    def run(self) -> Any:
        """Pipeline body; subclasses must override this method."""
        raise NotImplementedError

    def _require_control_context(self):
        """Return ``(scheduler, record)`` if control calls are allowed now.

        Raises:
            RuntimeError: If the pipeline has no record or scheduler, is not
                RUNNING, or the caller is not on the coordinator thread.
        """
        record = getattr(self, "_stagegate_record", None)
        scheduler = getattr(self, "_stagegate_scheduler", None)
        attached = record is not None and scheduler is not None
        if not attached or record.state is not PipelineState.RUNNING:
            raise RuntimeError("pipeline control requires a running pipeline")
        if record.coordinator_thread_ident != threading.get_ident():
            raise RuntimeError(
                "pipeline control is allowed only on the coordinator thread"
            )
        return scheduler, record

    def task(
        self,
        fn: Callable[..., Any],
        *,
        resources: dict[str, int | float],
        args: tuple[Any, ...] = (),
        kwargs: dict[str, Any] | None = None,
        name: str | None = None,
    ) -> TaskBuilder:
        """Describe a task without submitting it; call ``.run()`` to submit.

        ``resources`` and ``kwargs`` are shallow-copied, so later changes to
        the caller's dicts do not affect the builder.
        """
        resource_copy = dict(resources)
        kwargs_copy = {} if kwargs is None else dict(kwargs)
        return TaskBuilder(
            pipeline=self,
            fn=fn,
            resources=resource_copy,
            args=args,
            kwargs=kwargs_copy,
            name=name,
        )

    def stage_forward(self) -> None:
        """Increment the record's ``stage_index`` under the scheduler lock."""
        scheduler, record = self._require_control_context()
        with scheduler._condition:
            record.stage_index = record.stage_index + 1

    def wait(
        self,
        handles: Iterable[TaskHandle],
        timeout: float | None = None,
        return_when: str = ALL_COMPLETED,
    ) -> tuple[set[TaskHandle], set[TaskHandle]]:
        """Block until ``return_when`` holds for ``handles`` or time runs out.

        Args:
            handles: Task handles whose records belong to this pipeline.
            timeout: ``None`` to wait indefinitely, else a non-negative number.
            return_when: One of the wait conditions; defaults to ALL_COMPLETED.

        Returns:
            ``(done, pending)`` sets as of the moment the call returns.

        Raises:
            RuntimeError: If called outside the pipeline's control context.
            ValueError, TypeError: As raised by ``validate_wait_request``.
        """
        scheduler, _ = self._require_control_context()

        def _owned_by_self(handle):
            return handle._record.pipeline_record.pipeline is self

        # timeout, return_when, handle types and ownership are all checked here.
        watched = validate_wait_request(
            handles,
            expected_type=TaskHandle,
            owner_check=_owned_by_self,
            timeout=timeout,
            return_when=return_when,
        )
        deadline = monotonic_deadline(timeout)

        with scheduler._condition:
            done, pending = split_done_pending(watched)
            while not should_return(
                done=done, pending=pending, return_when=return_when
            ):
                budget = remaining_timeout(deadline)
                if budget == 0.0:
                    # Deadline reached: hand back the current partition.
                    break
                scheduler._condition.wait(budget)
                done, pending = split_done_pending(watched)
            return done, pending