stabilize 0.9.2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- stabilize/__init__.py +29 -0
- stabilize/cli.py +1193 -0
- stabilize/context/__init__.py +7 -0
- stabilize/context/stage_context.py +170 -0
- stabilize/dag/__init__.py +15 -0
- stabilize/dag/graph.py +215 -0
- stabilize/dag/topological.py +199 -0
- stabilize/examples/__init__.py +1 -0
- stabilize/examples/docker-example.py +759 -0
- stabilize/examples/golden-standard-expected-result.txt +1 -0
- stabilize/examples/golden-standard.py +488 -0
- stabilize/examples/http-example.py +606 -0
- stabilize/examples/llama-example.py +662 -0
- stabilize/examples/python-example.py +731 -0
- stabilize/examples/shell-example.py +399 -0
- stabilize/examples/ssh-example.py +603 -0
- stabilize/handlers/__init__.py +53 -0
- stabilize/handlers/base.py +226 -0
- stabilize/handlers/complete_stage.py +209 -0
- stabilize/handlers/complete_task.py +75 -0
- stabilize/handlers/complete_workflow.py +150 -0
- stabilize/handlers/run_task.py +369 -0
- stabilize/handlers/start_stage.py +262 -0
- stabilize/handlers/start_task.py +74 -0
- stabilize/handlers/start_workflow.py +136 -0
- stabilize/launcher.py +307 -0
- stabilize/migrations/01KDQ4N9QPJ6Q4MCV3V9GHWPV4_initial_schema.sql +97 -0
- stabilize/migrations/01KDRK3TXW4R2GERC1WBCQYJGG_rag_embeddings.sql +25 -0
- stabilize/migrations/__init__.py +1 -0
- stabilize/models/__init__.py +15 -0
- stabilize/models/stage.py +389 -0
- stabilize/models/status.py +146 -0
- stabilize/models/task.py +125 -0
- stabilize/models/workflow.py +317 -0
- stabilize/orchestrator.py +113 -0
- stabilize/persistence/__init__.py +28 -0
- stabilize/persistence/connection.py +185 -0
- stabilize/persistence/factory.py +136 -0
- stabilize/persistence/memory.py +214 -0
- stabilize/persistence/postgres.py +655 -0
- stabilize/persistence/sqlite.py +674 -0
- stabilize/persistence/store.py +235 -0
- stabilize/queue/__init__.py +59 -0
- stabilize/queue/messages.py +377 -0
- stabilize/queue/processor.py +312 -0
- stabilize/queue/queue.py +526 -0
- stabilize/queue/sqlite_queue.py +354 -0
- stabilize/rag/__init__.py +19 -0
- stabilize/rag/assistant.py +459 -0
- stabilize/rag/cache.py +294 -0
- stabilize/stages/__init__.py +11 -0
- stabilize/stages/builder.py +253 -0
- stabilize/tasks/__init__.py +19 -0
- stabilize/tasks/interface.py +335 -0
- stabilize/tasks/registry.py +255 -0
- stabilize/tasks/result.py +283 -0
- stabilize-0.9.2.dist-info/METADATA +301 -0
- stabilize-0.9.2.dist-info/RECORD +61 -0
- stabilize-0.9.2.dist-info/WHEEL +4 -0
- stabilize-0.9.2.dist-info/entry_points.txt +2 -0
- stabilize-0.9.2.dist-info/licenses/LICENSE +201 -0
|
@@ -0,0 +1,335 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Task interface definitions.
|
|
3
|
+
|
|
4
|
+
This module defines the Task interface and its variants (RetryableTask,
|
|
5
|
+
SkippableTask) that all task implementations must follow.
|
|
6
|
+
"""
|
|
7
|
+
|
|
8
|
+
from __future__ import annotations
|
|
9
|
+
|
|
10
|
+
from abc import ABC, abstractmethod
|
|
11
|
+
from collections.abc import Callable
|
|
12
|
+
from datetime import timedelta
|
|
13
|
+
from typing import TYPE_CHECKING
|
|
14
|
+
|
|
15
|
+
from stabilize.tasks.result import TaskResult
|
|
16
|
+
|
|
17
|
+
if TYPE_CHECKING:
|
|
18
|
+
from stabilize.models.stage import StageExecution
|
|
19
|
+
|
|
20
|
+
|
|
21
|
+
class Task(ABC):
    """
    Abstract contract implemented by every task.

    A task is the smallest unit of work in a pipeline. It is handed the
    current stage, does its work, and reports back with a TaskResult that
    carries a status and any outputs.

    Example:
        class DeployTask(Task):
            def execute(self, stage: StageExecution) -> TaskResult:
                # Get inputs from context
                cluster = stage.context.get("cluster")
                image = stage.context.get("image")

                # Do the work
                deployment_id = deploy(cluster, image)

                # Return result with outputs
                return TaskResult.success(
                    outputs={"deploymentId": deployment_id}
                )
    """

    @abstractmethod
    def execute(self, stage: StageExecution) -> TaskResult:
        """
        Run the task against the given stage.

        Concrete subclasses must override this method.

        Args:
            stage: The stage execution context

        Returns:
            TaskResult indicating status and any outputs

        Raises:
            Exception: Any exception will be caught and handled by the runner
        """

    def on_timeout(self, stage: StageExecution) -> TaskResult | None:
        """
        Hook invoked when the task times out.

        The base implementation returns None, which means the default
        timeout behavior applies. Override to supply a custom result.

        Args:
            stage: The stage execution context

        Returns:
            Optional TaskResult to use instead of default timeout
        """
        return None

    def on_cancel(self, stage: StageExecution) -> TaskResult | None:
        """
        Hook invoked when the execution is canceled.

        The base implementation does nothing and returns None. Override to
        run cleanup logic on cancellation.

        Args:
            stage: The stage execution context

        Returns:
            Optional TaskResult with cleanup results
        """
        return None

    @property
    def aliases(self) -> list[str]:
        """
        Alternative type names under which this task is also known.

        Used for backward compatibility when task types are renamed. The
        base implementation declares no aliases.

        Returns:
            List of alternative type names
        """
        return list()
|
|
102
|
+
|
|
103
|
+
|
|
104
|
+
class RetryableTask(Task):
    """
    A task that is polled repeatedly, with a timeout and a backoff.

    A retryable task returns RUNNING status while it is still waiting for
    some condition. It is executed again after a backoff period until it
    succeeds, fails, or times out.

    Example:
        class WaitForDeployTask(RetryableTask):
            def get_timeout(self) -> timedelta:
                return timedelta(minutes=30)

            def get_backoff_period(self, stage, duration) -> timedelta:
                return timedelta(seconds=10)

            def execute(self, stage: StageExecution) -> TaskResult:
                deployment_id = stage.context.get("deploymentId")
                status = check_deployment_status(deployment_id)

                if status == "complete":
                    return TaskResult.success()
                elif status == "failed":
                    return TaskResult.terminal("Deployment failed")
                else:
                    return TaskResult.running()
    """

    @abstractmethod
    def get_timeout(self) -> timedelta:
        """
        Report the maximum time this task may run before timing out.

        Concrete subclasses must override this method.

        Returns:
            Maximum execution time
        """

    def get_backoff_period(
        self,
        stage: StageExecution,
        duration: timedelta,
    ) -> timedelta:
        """
        Report how long to wait before the next retry.

        The base implementation ignores both arguments and always returns
        one second. Override to implement dynamic backoff based on how long
        the task has been running.

        Args:
            stage: The stage execution context
            duration: How long the task has been running

        Returns:
            Time to wait before retrying
        """
        one_second = timedelta(seconds=1)
        return one_second

    def get_dynamic_timeout(self, stage: StageExecution) -> timedelta:
        """
        Report the timeout to apply for a specific stage.

        The base implementation delegates to get_timeout() and does not
        look at the stage. Override to implement context-based timeouts.

        Args:
            stage: The stage execution context

        Returns:
            Timeout duration
        """
        static_timeout = self.get_timeout()
        return static_timeout

    def get_dynamic_backoff_period(
        self,
        stage: StageExecution,
        duration: timedelta,
    ) -> timedelta:
        """
        Report the backoff to apply for a specific stage.

        The base implementation delegates to get_backoff_period() with the
        same arguments.

        Args:
            stage: The stage execution context
            duration: How long the task has been running

        Returns:
            Time to wait before retrying
        """
        backoff = self.get_backoff_period(stage, duration)
        return backoff
|
|
192
|
+
|
|
193
|
+
|
|
194
|
+
class OverridableTimeoutRetryableTask(RetryableTask):
    """
    A retryable task whose timeout can be overridden by the stage.

    The stage can set a 'stageTimeoutMs' context value (milliseconds) to
    override the default timeout returned by get_timeout().
    """

    def get_dynamic_timeout(self, stage: StageExecution) -> timedelta:
        """
        Get the timeout, potentially overridden by stage context.

        Args:
            stage: The stage execution context

        Returns:
            The 'stageTimeoutMs' context value converted to a timedelta when
            it is present and not None, otherwise the result of
            get_timeout()
        """
        override_ms = stage.context.get("stageTimeoutMs")
        # A key that is present but set to None would make timedelta() raise
        # TypeError; treat it the same as an absent override.
        if override_ms is None:
            return self.get_timeout()
        return timedelta(milliseconds=override_ms)
|
|
207
|
+
|
|
208
|
+
|
|
209
|
+
class SkippableTask(Task):
    """
    A task that may be skipped depending on a condition.

    Subclasses implement do_execute() for the real work and may override
    is_enabled() to decide when that work should be skipped.
    """

    def is_enabled(self, stage: StageExecution) -> bool:
        """
        Decide whether this task should run for the given stage.

        The base implementation always returns True. Override to implement
        skip logic.

        Args:
            stage: The stage execution context

        Returns:
            True if task should execute, False to skip
        """
        return True

    def execute(self, stage: StageExecution) -> TaskResult:
        """Run do_execute() when enabled, otherwise return a skipped result."""
        if self.is_enabled(stage):
            return self.do_execute(stage)
        return TaskResult.skipped()

    @abstractmethod
    def do_execute(self, stage: StageExecution) -> TaskResult:
        """
        Carry out the actual work of the task.

        Concrete subclasses must override this method. It is only called by
        execute() when is_enabled() returned True.

        Args:
            stage: The stage execution context

        Returns:
            TaskResult indicating status
        """
|
|
248
|
+
|
|
249
|
+
|
|
250
|
+
class CallableTask(Task):
    """
    A task that wraps a callable function.

    Allows using simple functions as tasks without creating a class.

    Example:
        def my_task(stage: StageExecution) -> TaskResult:
            return TaskResult.success(outputs={"result": "done"})

        task = CallableTask(my_task)
    """

    def __init__(
        self,
        func: Callable[[StageExecution], TaskResult],
        name: str | None = None,
    ) -> None:
        """
        Initialize with a callable.

        Args:
            func: The function to call
            name: Optional name for the task. When omitted (or empty), the
                callable's ``__name__`` is used; callables that have no
                ``__name__`` (for example ``functools.partial`` objects or
                instances defining ``__call__``) fall back to the name of
                their type.
        """
        self._func = func
        # Not every callable has __name__, so avoid raising AttributeError
        # for valid inputs such as functools.partial objects.
        self._name = name or getattr(func, "__name__", None) or type(func).__name__

    def execute(self, stage: StageExecution) -> TaskResult:
        """Execute the wrapped function with the stage and return its result."""
        return self._func(stage)

    @property
    def name(self) -> str:
        """Get the task name."""
        return self._name
|
|
286
|
+
|
|
287
|
+
|
|
288
|
+
class NoOpTask(Task):
    """
    A task that performs no work.

    Useful for testing or placeholder stages.
    """

    def execute(self, stage: StageExecution) -> TaskResult:
        """Ignore the stage and report success straight away."""
        outcome = TaskResult.success()
        return outcome
|
|
298
|
+
|
|
299
|
+
|
|
300
|
+
class WaitTask(RetryableTask):
    """
    A task that completes once a configured duration has passed.

    Reads 'waitTime' from stage context (in seconds). The start of the wait
    is tracked through the 'waitStartTime' context value.
    """

    def get_timeout(self) -> timedelta:
        """Allow up to 24 hours before the wait itself times out."""
        return timedelta(hours=24)

    def get_backoff_period(
        self,
        stage: StageExecution,
        duration: timedelta,
    ) -> timedelta:
        """Re-check once per second, regardless of stage or duration."""
        return timedelta(seconds=1)

    def execute(self, stage: StageExecution) -> TaskResult:
        """
        Report whether the configured wait has elapsed.

        Args:
            stage: The stage execution context

        Returns:
            A running result carrying 'waitStartTime' on the first call, a
            plain running result while the wait is still in progress, and a
            success result with 'waitedSeconds' once enough time has passed
        """
        import time

        required_seconds = stage.context.get("waitTime", 0)
        started_at = stage.context.get("waitStartTime")
        now = int(time.time())

        # No recorded start yet: this is the first execution, so record it.
        if started_at is None:
            return TaskResult.running(context={"waitStartTime": now})

        waited = now - started_at
        if waited >= required_seconds:
            return TaskResult.success(outputs={"waitedSeconds": waited})
        return TaskResult.running()
|
|
@@ -0,0 +1,255 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Task registry for resolving task implementations.
|
|
3
|
+
|
|
4
|
+
This module provides the TaskRegistry class for registering and resolving
|
|
5
|
+
task implementations by name or type.
|
|
6
|
+
"""
|
|
7
|
+
|
|
8
|
+
from __future__ import annotations
|
|
9
|
+
|
|
10
|
+
import logging
|
|
11
|
+
from collections.abc import Callable
|
|
12
|
+
|
|
13
|
+
from stabilize.tasks.interface import CallableTask, Task
|
|
14
|
+
from stabilize.tasks.result import TaskResult
|
|
15
|
+
|
|
16
|
+
if False: # TYPE_CHECKING
|
|
17
|
+
from stabilize.models.stage import StageExecution
|
|
18
|
+
|
|
19
|
+
logger = logging.getLogger(__name__)
|
|
20
|
+
|
|
21
|
+
# Type for task callable
|
|
22
|
+
TaskCallable = Callable[["StageExecution"], TaskResult]
|
|
23
|
+
|
|
24
|
+
# Type for task implementation - can be a Task class or callable
|
|
25
|
+
TaskImplementation = type[Task] | Task | TaskCallable
|
|
26
|
+
|
|
27
|
+
|
|
28
|
+
class TaskNotFoundError(Exception):
    """Error signalling that no task is registered for a requested type."""

    def __init__(self, task_type: str):
        """
        Build the error for the given task type.

        Args:
            task_type: The task type name that could not be resolved; also
                stored on the instance as ``task_type``
        """
        message = f"No task found for type: {task_type}"
        super().__init__(message)
        self.task_type = task_type
|
|
34
|
+
|
|
35
|
+
|
|
36
|
+
class TaskRegistry:
    """
    Registry for task implementations.

    Allows registering tasks by name and resolving them at runtime.
    Supports:
    - Task classes (instantiated on resolve)
    - Task instances (used directly)
    - Callable functions (wrapped in CallableTask)

    Example:
        registry = TaskRegistry()

        # Register a task class
        registry.register("deploy", DeployTask)

        # Register a task instance
        registry.register("notify", NotifyTask(slack_client))

        # Register a function
        @registry.task("validate")
        def validate_inputs(stage):
            return TaskResult.success()

        # Resolve and use
        task = registry.get("deploy")
        result = task.execute(stage)
    """

    def __init__(self) -> None:
        """Create an empty registry."""
        # Task type name -> registered implementation (class, instance or callable).
        self._tasks: dict[str, TaskImplementation] = {}
        # Alias -> canonical task type name.
        self._aliases: dict[str, str] = {}

    def register(
        self,
        name: str,
        task: TaskImplementation,
        aliases: list[str] | None = None,
    ) -> None:
        """
        Register a task implementation.

        Registering a name that already exists replaces the previous
        implementation and logs a warning; no exception is raised.

        Args:
            name: The task type name
            task: Task class, instance, or callable
            aliases: Optional alternative names
        """
        if name in self._tasks:
            logger.warning("Overwriting existing task registration: %s", name)

        self._tasks[name] = task

        # Explicitly supplied aliases.
        for alias in aliases or ():
            self._aliases[alias] = name

        # Aliases declared by the task itself. ``aliases`` is an instance
        # property, so a Task class is instantiated (with no arguments) in
        # order to read it.
        if isinstance(task, type) and issubclass(task, Task):
            declared = task().aliases
        elif isinstance(task, Task):
            declared = task.aliases
        else:
            declared = ()
        for alias in declared:
            self._aliases[alias] = name

        logger.debug("Registered task: %s", name)

    def register_class(
        self,
        task_class: type[Task],
        name: str | None = None,
        aliases: list[str] | None = None,
    ) -> None:
        """
        Register a task class using its class name.

        Args:
            task_class: The task class to register
            name: Optional name override; defaults to the class's __name__
            aliases: Optional alternative names, passed through to register()
        """
        task_name = name or task_class.__name__
        self.register(task_name, task_class, aliases)

    def task(
        self,
        name: str,
        aliases: list[str] | None = None,
    ) -> Callable[[TaskCallable], TaskCallable]:
        """
        Decorator to register a function as a task.

        The decorated function is returned unchanged, so it remains directly
        callable.

        Args:
            name: The task type name
            aliases: Optional alternative names

        Returns:
            Decorator function

        Example:
            @registry.task("validate")
            def validate_inputs(stage):
                return TaskResult.success()
        """

        def decorator(func: TaskCallable) -> TaskCallable:
            self.register(name, func, aliases)
            return func

        return decorator

    def get(self, name: str) -> Task:
        """
        Get a task implementation by name.

        Aliases are consulted first, so an alias takes precedence over a
        task registered under the same name.

        Args:
            name: The task type name or an alias

        Returns:
            A Task instance: the registered instance itself, a new instance
            of a registered Task class, or a CallableTask wrapping a
            registered callable

        Raises:
            TaskNotFoundError: If task not found (carries the resolved name)
        """
        resolved_name = self._aliases.get(name, name)

        if resolved_name not in self._tasks:
            raise TaskNotFoundError(resolved_name)

        impl = self._tasks[resolved_name]

        # Handle the different registration types.
        if isinstance(impl, Task):
            return impl
        if isinstance(impl, type) and issubclass(impl, Task):
            return impl()
        if callable(impl):
            # Cast only narrows the static type for CallableTask.
            from typing import cast

            func = cast(TaskCallable, impl)
            return CallableTask(func, name=resolved_name)
        raise TaskNotFoundError(resolved_name)

    def get_by_class(self, class_name: str) -> Task:
        """
        Get a task by its implementing class name.

        Lookup is attempted against registered task names in this order: the
        name as given, the part after the last dot, and that part lowercased.

        Args:
            class_name: Fully qualified or simple class name

        Returns:
            A Task instance

        Raises:
            TaskNotFoundError: If task not found
        """
        simple_name = class_name.split(".")[-1]
        for candidate in (class_name, simple_name, simple_name.lower()):
            if candidate in self._tasks:
                return self.get(candidate)

        raise TaskNotFoundError(class_name)

    def has(self, name: str) -> bool:
        """Check if a task is registered under the given name or alias."""
        resolved_name = self._aliases.get(name, name)
        return resolved_name in self._tasks

    def list_tasks(self) -> list[str]:
        """Get all registered task names (aliases are not included)."""
        return list(self._tasks)

    def clear(self) -> None:
        """Clear all registrations and aliases."""
        self._tasks.clear()
        self._aliases.clear()
|
|
225
|
+
|
|
226
|
+
|
|
227
|
+
# Global registry instance
|
|
228
|
+
_default_registry: TaskRegistry | None = None
|
|
229
|
+
|
|
230
|
+
|
|
231
|
+
def get_default_registry() -> TaskRegistry:
    """
    Return the shared module-level task registry.

    The registry is created on the first call and the same instance is
    returned on every later call.
    """
    global _default_registry
    registry = _default_registry
    if registry is None:
        registry = _default_registry = TaskRegistry()
    return registry
|
|
237
|
+
|
|
238
|
+
|
|
239
|
+
def register_task(
    name: str,
    task: TaskImplementation,
    aliases: list[str] | None = None,
) -> None:
    """
    Register a task implementation in the default registry.

    Args:
        name: The task type name
        task: Task class, instance, or callable
        aliases: Optional alternative names
    """
    registry = get_default_registry()
    registry.register(name, task, aliases)
|
|
246
|
+
|
|
247
|
+
|
|
248
|
+
def get_task(name: str) -> Task:
    """
    Resolve a task by name from the default registry.

    Args:
        name: The task type name

    Returns:
        The Task returned by the default registry's get()
    """
    registry = get_default_registry()
    return registry.get(name)
|
|
251
|
+
|
|
252
|
+
|
|
253
|
+
def task(name: str, aliases: list[str] | None = None) -> Callable[[TaskCallable], TaskCallable]:
    """
    Build a decorator that registers a function in the default registry.

    Args:
        name: The task type name
        aliases: Optional alternative names

    Returns:
        The decorator produced by the default registry's task()
    """
    registry = get_default_registry()
    return registry.task(name, aliases)
|