auto-workflow 0.1.0 (auto_workflow-0.1.0-py3-none-any.whl)
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- assets/logo.svg +6524 -0
- auto_workflow/__init__.py +40 -0
- auto_workflow/__main__.py +9 -0
- auto_workflow/artifacts.py +119 -0
- auto_workflow/build.py +158 -0
- auto_workflow/cache.py +111 -0
- auto_workflow/cli.py +78 -0
- auto_workflow/config.py +76 -0
- auto_workflow/context.py +32 -0
- auto_workflow/dag.py +80 -0
- auto_workflow/events.py +22 -0
- auto_workflow/exceptions.py +42 -0
- auto_workflow/execution.py +45 -0
- auto_workflow/fanout.py +59 -0
- auto_workflow/flow.py +165 -0
- auto_workflow/lifecycle.py +16 -0
- auto_workflow/logging_middleware.py +191 -0
- auto_workflow/metrics_provider.py +35 -0
- auto_workflow/middleware.py +59 -0
- auto_workflow/py.typed +0 -0
- auto_workflow/scheduler.py +362 -0
- auto_workflow/secrets.py +45 -0
- auto_workflow/task.py +158 -0
- auto_workflow/tracing.py +29 -0
- auto_workflow/types.py +27 -0
- auto_workflow/utils.py +39 -0
- auto_workflow-0.1.0.dist-info/LICENSE +674 -0
- auto_workflow-0.1.0.dist-info/METADATA +423 -0
- auto_workflow-0.1.0.dist-info/RECORD +31 -0
- auto_workflow-0.1.0.dist-info/WHEEL +4 -0
- auto_workflow-0.1.0.dist-info/entry_points.txt +3 -0
auto_workflow/middleware.py
ADDED
@@ -0,0 +1,59 @@
"""Middleware chaining for task execution (extensible)."""

from __future__ import annotations

from collections.abc import Awaitable, Callable
from typing import Any

from .events import emit

TaskCallable = Callable[[], Awaitable[Any]]
Middleware = Callable[[TaskCallable, Any, tuple, dict], Awaitable[Any]]

_registry: list[Middleware] = []


def register(mw: Middleware) -> None:
    _registry.append(mw)


def clear() -> None:  # pragma: no cover
    _registry.clear()


async def _call_chain(
    index: int, core: TaskCallable, task_def: Any, args: tuple, kwargs: dict
) -> Any:
    if index == len(_registry):
        return await core()
    mw = _registry[index]
    executing_core = False

    async def nxt():
        nonlocal executing_core
        executing_core = True
        return await _call_chain(index + 1, core, task_def, args, kwargs)

    try:
        return await mw(nxt, task_def, args, kwargs)
    except Exception as e:  # noqa: BLE001
        # If the exception occurred during/after core execution, propagate (task error)
        if executing_core:
            raise
        # Otherwise classify it as a middleware error and continue the chain
        emit(
            "middleware_error",
            {
                "task": getattr(task_def, "name", "unknown"),
                "middleware_index": index,
                "error": repr(e),
            },
        )
        return await _call_chain(index + 1, core, task_def, args, kwargs)


def get_task_middleware_chain() -> Callable[[TaskCallable, Any, tuple, dict], Awaitable[Any]]:
    async def runner(core: TaskCallable, task_def: Any, args: tuple, kwargs: dict) -> Any:
        return await _call_chain(0, core, task_def, args, kwargs)

    return runner

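For illustration (not part of the published diff): a middleware matching the Middleware alias above receives the next callable in the chain plus the task definition and the call arguments, and must await the callable to continue. A minimal timing sketch, assuming the module is imported as auto_workflow.middleware:

import time

from auto_workflow import middleware


async def timing_middleware(nxt, task_def, args, kwargs):
    # Await the rest of the chain (ending at the task core) and time it.
    start = time.perf_counter()
    try:
        return await nxt()
    finally:
        elapsed_ms = (time.perf_counter() - start) * 1000.0
        print(f"{getattr(task_def, 'name', 'unknown')} took {elapsed_ms:.1f} ms")


middleware.register(timing_middleware)

Note the executing_core flag above: an exception raised before nxt() is awaited is reported as a "middleware_error" event and the chain continues, while one raised after the core has started propagates as a task failure.
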
auto_workflow/py.typed
ADDED
File without changes
auto_workflow/scheduler.py
ADDED
@@ -0,0 +1,362 @@
"""Simple topological scheduler executing TaskInvocations respecting dependencies."""

from __future__ import annotations

import asyncio
import time
from typing import Any

from .artifacts import get_store
from .build import TaskInvocation
from .cache import get_result_cache
from .dag import DAG
from .events import emit
from .exceptions import AggregateTaskError, TaskExecutionError
from .metrics_provider import get_metrics_provider
from .middleware import get_task_middleware_chain
from .tracing import get_tracer

# result cache implementation centralized in auto_workflow.cache


class FailurePolicy:
    FAIL_FAST = "fail_fast"
    CONTINUE = "continue"
    AGGREGATE = "aggregate"


async def execute_dag(
    invocations: list[TaskInvocation],
    *,
    failure_policy: str = FailurePolicy.FAIL_FAST,
    max_concurrency: int | None = None,
    cancel_event: asyncio.Event | None = None,
    dynamic_roots: list[Any] | None = None,
) -> dict[str, Any]:
    dag = DAG()
    for inv in invocations:
        dag.add_node(inv.name)
        for up in inv.upstream:
            dag.add_edge(up, inv.name)
    order = dag.topological_sort()
    inv_map = {inv.name: inv for inv in invocations}
    results: dict[str, Any] = {}
    cache = get_result_cache()
    # simple ready-set processing
    remaining_deps = {name: set(inv_map[name].upstream) for name in order}
    ready = [n for n, deps in remaining_deps.items() if not deps]
    pending_tasks: dict[str, asyncio.Task[Any]] = {}

    sem = asyncio.Semaphore(max_concurrency or len(order))
    task_errors: list[TaskExecutionError] = []

    # For tasks with cache_ttl we may have multiple identical invocations ready simultaneously
    # (e.g., synchronous functions now offloaded to threads). Use an in-flight map to deduplicate.
    inflight: dict[str, asyncio.Future[Any]] = {}
    inflight_lock = asyncio.Lock()
    # Per dynamic fan-out concurrency control
    placeholder_sems: dict[int, asyncio.Semaphore] = {}
    child_to_placeholder: dict[str, int] = {}

    async def schedule(name: str) -> None:
        inv = inv_map[name]
        # prepare args by replacing dependencies with their results
        resolved_args = _hydrate(inv.args, results)
        resolved_kwargs = _hydrate(inv.kwargs, results)
        # caching (per task-definition attributes)
        cache_ttl = getattr(inv.definition, "cache_ttl", None)
        cache_key = inv.definition.cache_key(*resolved_args, **resolved_kwargs)
        if cache_ttl is not None:
            # Atomic check-and-register for in-flight de-dup
            async with inflight_lock:
                cached = cache.get(cache_key, cache_ttl)
                if cached is not None:
                    from contextlib import suppress

                    with suppress(Exception):
                        get_metrics_provider().inc("cache_hits")
                    results[name] = cached
                    return
                existing = inflight.get(cache_key)
                if existing is None:
                    # register placeholder before any execution begins so followers can await
                    inflight[cache_key] = asyncio.get_running_loop().create_future()
            # If an identical task is already running, await its future
            if existing is not None:
                from contextlib import suppress

                with suppress(Exception):
                    get_metrics_provider().inc("dedup_joins")
                try:
                    value = await existing
                except Exception as e:  # propagate underlying failure
                    raise e
                results[name] = value
                return
        try:
            emit("task_started", {"task": inv.task_name, "node": name})
            start = time.time()
            async with sem:
                tracer = get_tracer()
                async with tracer.span(f"task:{inv.task_name}", node=name):

                    async def core_run():
                        return await inv.definition.run(*resolved_args, **resolved_kwargs)

                    # apply middleware chain
                    value = await get_task_middleware_chain()(
                        core_run, inv.definition, resolved_args, resolved_kwargs
                    )
            duration = time.time() - start
            # artifact persistence if requested
            if getattr(inv.definition, "persist", False):
                store = get_store()
                ref = store.put(value)
                value = ref
            results[name] = value
            if cache_ttl is not None:
                cache.set(cache_key, value)
                from contextlib import suppress

                with suppress(Exception):
                    get_metrics_provider().inc("cache_sets")
                # resolve and clean up the in-flight placeholder atomically
                async with inflight_lock:
                    fut = inflight.get(cache_key)
                    if fut and not fut.done():
                        fut.set_result(value)
                    inflight.pop(cache_key, None)
            emit(
                "task_succeeded",
                {"task": inv.task_name, "node": name, "duration_ms": duration * 1000.0},
            )
            mp = get_metrics_provider()
            mp.inc("tasks_succeeded")
            mp.observe("task_duration_ms", duration * 1000.0)
        except Exception as e:  # noqa: BLE001
            te = TaskExecutionError(inv.task_name, e)
            if cache_ttl is not None:
                async with inflight_lock:
                    fut = inflight.get(cache_key)
                    if fut and not fut.done():
                        fut.set_exception(e)
                    inflight.pop(cache_key, None)
            if failure_policy == FailurePolicy.FAIL_FAST:
                emit("task_failed", {"task": inv.task_name, "node": name, "error": repr(e)})
                # record failure result so downstream inspection doesn't KeyError before propagation
                results[name] = te
                mp = get_metrics_provider()
                mp.inc("tasks_failed")
                raise te from None
            task_errors.append(te)
            results[name] = te
            emit("task_failed", {"task": inv.task_name, "node": name, "error": repr(e)})
            mp = get_metrics_provider()
            mp.inc("tasks_failed")

    # Map consumer invocation name -> list of DynamicFanOut placeholders it references
    from .fanout import DynamicFanOut

    consumer_placeholders: dict[str, list[DynamicFanOut]] = {}
    all_placeholders: list[DynamicFanOut] = []
    if dynamic_roots:
        for p in dynamic_roots:
            if isinstance(p, DynamicFanOut):
                all_placeholders.append(p)
    for inv in invocations:
        for arg in list(inv.args) + list(inv.kwargs.values()):
            if isinstance(arg, DynamicFanOut):
                consumer_placeholders.setdefault(inv.name, []).append(arg)
                all_placeholders.append(arg)

    while ready or pending_tasks:
        if cancel_event and cancel_event.is_set():
            # Cancel all running tasks
            for t in pending_tasks.values():
                t.cancel()
            # Wait for cancellation to propagate
            if pending_tasks:
                await asyncio.gather(*pending_tasks.values(), return_exceptions=True)
            break
        # Determine which ready nodes are truly runnable (no unexpanded placeholders they depend on)
        runnable: list[str] = []
        for node in list(ready):
            placeholders = consumer_placeholders.get(node, [])
            gate = False
            for p in placeholders:
                # Need expansion complete and all child results materialized
                if (not p._expanded) or any(child.name not in results for child in p):
                    gate = True
                    break
            if gate:
                continue
            runnable.append(node)
        # Enforce priority ordering (higher first) and deterministic tie-break by node name
        runnable.sort(
            key=lambda n: (
                -getattr(inv_map[n].definition, "priority", 0),
                n,
            )
        )
        for node in runnable:
            # Skip if already scheduled (defensive against accidental duplicates in ready)
            if node in pending_tasks:
                # Remove duplicate occurrence
                ready.remove(node)
                continue
            ready.remove(node)
            # If node belongs to a dynamic placeholder with a concurrency limit, wrap schedule
            pid = child_to_placeholder.get(node)
            if pid is not None and pid in placeholder_sems:
                sem_p = placeholder_sems[pid]

                async def schedule_with_limit(n=node, s=sem_p):
                    async with s:
                        await schedule(n)

                pending_tasks[node] = asyncio.create_task(schedule_with_limit())
            else:
                pending_tasks[node] = asyncio.create_task(schedule(node))
        if not pending_tasks:
            break
        done, _ = await asyncio.wait(pending_tasks.values(), return_when=asyncio.FIRST_COMPLETED)
        finished_names = [n for n, t in list(pending_tasks.items()) if t in done]
        for fname in finished_names:
            task = pending_tasks.pop(fname)
            exc = task.exception()
            if exc:
                raise exc
            # After task success, check if any dynamic fan-outs depend on it.
            # Evaluate expansion conditions for all placeholders (supports nesting)
            for placeholder in list(all_placeholders):
                if placeholder._expanded:
                    continue
                src = placeholder._source
                ready_to_expand = False
                if isinstance(src, TaskInvocation) and src.name == fname:
                    ready_to_expand = True
                else:
                    # Nested: expand when all children have results
                    if (
                        isinstance(src, DynamicFanOut)
                        and src._expanded
                        and all(child.name in results for child in src)
                    ):
                        ready_to_expand = True
                if ready_to_expand:
                    # Derive the source value: a task result, or each child result collected into a list
                    if isinstance(src, TaskInvocation):
                        source_value = results[src.name]
                    else:
                        source_value = [results[c.name] for c in src]
                    if not isinstance(source_value, (list, tuple, set)):
                        raise TaskExecutionError(
                            getattr(src, "task_name", "dynamic"),
                            RuntimeError("Dynamic fan_out source must return an iterable"),
                        )
                    placeholder.expand(source_value)
                    for child_inv in placeholder:
                        # Track per-placeholder concurrency if specified
                        if getattr(placeholder, "_max_conc", None):
                            pid = id(placeholder)
                            if pid not in placeholder_sems:
                                placeholder_sems[pid] = asyncio.Semaphore(placeholder._max_conc)
                            child_to_placeholder[child_inv.name] = pid
                        if child_inv.name not in dag.nodes:
                            dag.add_node(child_inv.name)
                            # Edge from src itself, or from all of src's children
                            if isinstance(src, TaskInvocation):
                                dag.add_edge(src.name, child_inv.name)
                                # Source already finished, so dependency is satisfied immediately
                                remaining_deps[child_inv.name] = set()
                                ready.append(child_inv.name)
                            else:
                                # depend on all children of src
                                deps = {c.name for c in src}
                                for d in deps:
                                    dag.add_edge(d, child_inv.name)
                                # All deps completed (readiness condition); mark none remaining
                                remaining_deps[child_inv.name] = set()
                                ready.append(child_inv.name)
                        # ensure inv_map is aware of the new child
                        if child_inv.name not in inv_map:
                            inv_map[child_inv.name] = child_inv
                    # Update consumers
                    for consumer in invocations:
                        if consumer is src:
                            continue
                        replaced = False

                        def _walk(o, target=placeholder):  # bind loop var
                            nonlocal replaced
                            if o is target:
                                replaced = True
                                return list(target)
                            # Do not traverse into unrelated DynamicFanOut placeholders;
                            # treat them as atomic until their own expansion pass.
                            from .fanout import DynamicFanOut as DynamicFanOutPlaceholder

                            if isinstance(o, DynamicFanOutPlaceholder):
                                return o
                            if isinstance(o, list):
                                return [_walk(x, target) for x in o]
                            if isinstance(o, tuple):
                                return tuple(_walk(x, target) for x in o)
                            if isinstance(o, dict):
                                return {k: _walk(v, target) for k, v in o.items()}
                            return o

                        consumer.args = (
                            _walk(list(consumer.args)) if consumer.args else consumer.args
                        )
                        consumer.kwargs = (
                            _walk(consumer.kwargs) if consumer.kwargs else consumer.kwargs
                        )
                        if replaced:
                            for child_inv in placeholder:
                                dag.add_edge(child_inv.name, consumer.name)
                                remaining_deps.setdefault(consumer.name, set()).add(child_inv.name)
                    # Placeholders remain for nested detection; do not delete
            # update downstream readiness
            # Use set() to guard against duplicate edges causing repeated scheduling attempts
            for child in set(dag.nodes[fname].downstream):
                remaining_deps[child].discard(fname)
                # Skip scheduling if any upstream failed and policy is not CONTINUE
                upstream_failed = any(
                    isinstance(results.get(up), TaskExecutionError)
                    for up in dag.nodes[child].upstream
                    if up in results
                )
                if upstream_failed and failure_policy != FailurePolicy.CONTINUE:
                    results[child] = TaskExecutionError(child, RuntimeError("Upstream failed"))
                    continue
                if (
                    not remaining_deps[child]
                    and child not in ready
                    and child not in pending_tasks
                    and child not in results
                ):
                    ready.append(child)
    if failure_policy == FailurePolicy.AGGREGATE and task_errors:
        raise AggregateTaskError(task_errors)
    return results


def _hydrate(struct: Any, results: dict[str, Any]) -> Any:
    from .build import TaskInvocation

    if isinstance(struct, TaskInvocation):
        return results[struct.name]
    if isinstance(struct, list):
        return [_hydrate(s, results) for s in struct]
    if isinstance(struct, tuple):
        return tuple(_hydrate(s, results) for s in struct)
    if isinstance(struct, set):
        return {_hydrate(s, results) for s in struct}
    if isinstance(struct, dict):
        return {k: _hydrate(v, results) for k, v in struct.items()}
    return struct

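The cache_ttl branch above is a single-flight pattern: under the lock, the first invocation for a cache key registers an asyncio Future, and identical invocations arriving while it runs await that Future instead of recomputing. A standalone sketch of the same pattern (illustrative only, not an API of this package; single_flight and its globals are hypothetical names):

import asyncio

_inflight: dict[str, asyncio.Future] = {}
_lock = asyncio.Lock()


async def single_flight(key, compute):
    # First caller for `key` becomes the owner and registers a Future;
    # concurrent callers for the same key await that Future instead.
    async with _lock:
        fut = _inflight.get(key)
        follower = fut
        if fut is None:
            fut = asyncio.get_running_loop().create_future()
            _inflight[key] = fut
    if follower is not None:
        return await follower  # join the in-flight computation
    try:
        value = await compute()
        fut.set_result(value)
        return value
    except Exception as e:
        fut.set_exception(e)  # followers observe the same failure
        raise
    finally:
        async with _lock:
            _inflight.pop(key, None)
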
auto_workflow/secrets.py
ADDED
@@ -0,0 +1,45 @@
"""Secrets provider abstraction."""

from __future__ import annotations

import os
from typing import Protocol


class SecretsProvider(Protocol):  # pragma: no cover - interface
    def get(self, key: str) -> str | None: ...


class EnvSecrets(SecretsProvider):
    def get(self, key: str) -> str | None:
        return os.environ.get(key)


class StaticMappingSecrets(SecretsProvider):
    def __init__(self, data: dict[str, str]):
        self.data = data

    def get(self, key: str) -> str | None:
        return self.data.get(key)


class DummyVaultSecrets(SecretsProvider):  # placeholder for future HashiCorp Vault integration
    def __init__(self, prefix: str = ""):
        self.prefix = prefix
        # In a real implementation, store client/session

    def get(self, key: str) -> str | None:  # pragma: no cover - placeholder
        # Would query Vault; here just an environment fallback with prefix
        return os.environ.get(self.prefix + key)


_provider: SecretsProvider = EnvSecrets()


def set_secrets_provider(p: SecretsProvider) -> None:
    global _provider
    _provider = p


def secret(key: str) -> str | None:
    return _provider.get(key)

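Usage is two calls: swap the process-wide provider, then read keys through secret(). A short sketch (the key names are made up):

from auto_workflow.secrets import StaticMappingSecrets, secret, set_secrets_provider

# The default provider is EnvSecrets (os.environ); override it for tests:
set_secrets_provider(StaticMappingSecrets({"API_KEY": "test-token"}))
assert secret("API_KEY") == "test-token"
assert secret("MISSING") is None  # absent keys return None rather than raising
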
auto_workflow/task.py
ADDED
@@ -0,0 +1,158 @@
"""Task abstraction and decorator."""

from __future__ import annotations

import asyncio
import builtins
import functools
import inspect
from collections.abc import Callable
from dataclasses import dataclass, field
from typing import Any

from .build import current_build_context
from .events import emit
from .exceptions import RetryExhaustedError, TimeoutError
from .tracing import get_tracer
from .types import TaskFn
from .utils import default_cache_key, maybe_await


@dataclass(slots=True)
class TaskDefinition:
    name: str
    fn: TaskFn
    original_fn: TaskFn | None = None
    retries: int = 0
    retry_backoff: float = 0.0
    retry_jitter: float = 0.0
    timeout: float | None = None
    cache_ttl: int | None = None
    cache_key_fn: Callable[..., str] = default_cache_key
    tags: set[str] = field(default_factory=set)
    run_in: str = "async"  # one of: async, thread, process
    persist: bool = False  # store large result via artifact store (future)
    priority: int = 0  # higher runs earlier when ready

    async def run(self, *args: Any, **kwargs: Any) -> Any:
        attempt = 0
        while True:

            async def _invoke():
                tracer = get_tracer()
                async with tracer.span(f"task:{self.name}"):
                    if self.run_in == "async":
                        return await maybe_await(self.fn(*args, **kwargs))
                    if self.run_in == "thread":
                        return await asyncio.to_thread(self.fn, *args, **kwargs)
                    if self.run_in == "process":
                        from .execution import get_process_pool

                        loop = asyncio.get_running_loop()
                        pool = get_process_pool()
                        import cloudpickle

                        from .execution import run_pickled

                        fn_bytes = cloudpickle.dumps((self.fn, args, kwargs))
                        return await loop.run_in_executor(
                            pool, functools.partial(run_pickled, fn_bytes)
                        )
                    raise RuntimeError(f"Unknown run_in mode: {self.run_in}")

            try:
                if self.timeout:
                    return await asyncio.wait_for(_invoke(), timeout=self.timeout)
                return await _invoke()
            except builtins.TimeoutError as te:
                err: Exception = TimeoutError(self.name, te)
            except Exception as e:  # noqa: BLE001
                err = e
            # error path
            if attempt < self.retries:
                attempt += 1
                emit("task_retry", {"task": self.name, "attempt": attempt, "max": self.retries})
                sleep_dur = self.retry_backoff * (2 ** (attempt - 1))
                if self.retry_jitter:
                    import random

                    sleep_dur += random.uniform(0, self.retry_jitter)
                await asyncio.sleep(sleep_dur)
                continue
            if isinstance(err, TimeoutError):
                raise err
            raise RetryExhaustedError(self.name, err) from err

    def cache_key(self, *args: Any, **kwargs: Any) -> str:
        return self.cache_key_fn(self.fn, args, kwargs)

    def __call__(self, *args: Any, **kwargs: Any) -> Any:  # building vs immediate execution
        ctx = current_build_context()
        if ctx is None:
            # immediate execution outside a flow build (synchronous helper)
            tracer = get_tracer()

            async def _run():
                async with tracer.span(f"task:{self.name}"):
                    val = await self.run(*args, **kwargs)
                    if self.persist:
                        from .artifacts import get_store

                        store = get_store()
                        ref = store.put(val)
                        return ref
                    return val

            return asyncio.run(_run())
        return ctx.register(self.name, self.fn, args, kwargs, self)


def task(
    _fn: TaskFn | None = None,
    *,
    name: str | None = None,
    retries: int = 0,
    retry_backoff: float = 0.0,
    retry_jitter: float = 0.0,
    timeout: float | None = None,
    cache_ttl: int | None = None,
    cache_key_fn: Callable[..., str] = default_cache_key,
    tags: set[str] | None = None,
    run_in: str | None = None,
    persist: bool = False,
    priority: int = 0,
) -> Callable[[TaskFn], TaskDefinition]:
    """Decorator to define a task.

    Auto executor selection:
    - If run_in explicitly provided -> honored.
    - If function is an async def -> defaults to "async".
    - Else (synchronous callable) -> defaults to "thread" to avoid blocking the event loop.
    """

    def wrap(fn: TaskFn) -> TaskDefinition:
        inferred_run_in = run_in
        if inferred_run_in is None:
            inferred_run_in = (
                "async" if inspect.iscoroutinefunction(fn) else "thread"
            )  # safe default for blocking sync functions
        td_obj = TaskDefinition(
            name=name or fn.__name__,
            fn=fn,
            original_fn=fn,
            retries=retries,
            retry_backoff=retry_backoff,
            retry_jitter=retry_jitter,
            timeout=timeout,
            cache_ttl=cache_ttl,
            cache_key_fn=cache_key_fn,
            tags=tags or set(),
            run_in=inferred_run_in,
            persist=persist,
            priority=priority,
        )
        return td_obj

    if _fn is not None:
        return wrap(_fn)
    return wrap

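As the docstring notes, run_in is inferred from the wrapped callable when not given explicitly. A usage sketch (the task functions are made-up examples, not part of the package):

from auto_workflow.task import task


@task(retries=2, retry_backoff=0.5, timeout=10.0)
def checksum(path: str) -> int:
    # Plain sync function: run_in is inferred as "thread",
    # so the blocking file read will not stall the event loop.
    with open(path, "rb") as f:
        return sum(f.read())


@task
async def double(x: int) -> int:
    # Coroutine function: run_in is inferred as "async".
    return x * 2

Called outside a flow build context, a TaskDefinition executes immediately via asyncio.run; inside a build context, __call__ instead registers the invocation with the builder for deferred DAG execution.
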
auto_workflow/tracing.py
ADDED
@@ -0,0 +1,29 @@
"""Tracing scaffold (OpenTelemetry friendly)."""

from __future__ import annotations

import time
from contextlib import asynccontextmanager
from typing import Any


class DummyTracer:
    @asynccontextmanager
    async def span(self, name: str, **attrs: Any):  # pragma: no cover - simple scaffold
        start = time.time()
        try:
            yield {"start": start, "name": name, **attrs}
        finally:
            _ = time.time() - start


_tracer: DummyTracer = DummyTracer()


def get_tracer() -> DummyTracer:
    return _tracer


def set_tracer(t: DummyTracer) -> None:
    global _tracer
    _tracer = t

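set_tracer is annotated with DummyTracer, but at runtime any object exposing the same async span(name, **attrs) context manager works, since callers only ever enter the span. A minimal replacement sketch (PrintTracer is a hypothetical example class):

import time
from contextlib import asynccontextmanager

from auto_workflow.tracing import set_tracer


class PrintTracer:
    @asynccontextmanager
    async def span(self, name, **attrs):
        # Same shape as DummyTracer.span: yield a dict, log on exit.
        start = time.time()
        try:
            yield {"start": start, "name": name, **attrs}
        finally:
            print(f"{name}: {(time.time() - start) * 1000.0:.1f} ms", attrs)


set_tracer(PrintTracer())  # type: ignore[arg-type]  # duck-typed at runtime
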
auto_workflow/types.py
ADDED
@@ -0,0 +1,27 @@
"""Common type aliases and Protocols."""

from __future__ import annotations

from collections.abc import Awaitable, Callable
from typing import Protocol, TypeVar, runtime_checkable

T = TypeVar("T")
R = TypeVar("R")

TaskFn = Callable[..., R] | Callable[..., Awaitable[R]]


@runtime_checkable
class SupportsHash(Protocol):
    def __hash__(self) -> int: ...  # pragma: no cover


CacheKey = str


class CancelledSentinel:
    def __repr__(self) -> str:  # pragma: no cover
        return "<CANCELLED>"


CANCELLED = CancelledSentinel()

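Two quick behavioral notes as a sketch: SupportsHash is runtime_checkable, so isinstance only verifies that __hash__ is present, and CANCELLED is a module-level singleton meant to be compared by identity:

from auto_workflow.types import CANCELLED, SupportsHash

assert isinstance("any hashable value", SupportsHash)  # structural check at runtime

result = CANCELLED
if result is CANCELLED:  # identity comparison, not equality
    print(repr(result))  # <CANCELLED>
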