auto_workflow-0.1.0-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
auto_workflow/dag.py ADDED
@@ -0,0 +1,80 @@
+ """Internal DAG representation and cycle detection."""
+
+ from __future__ import annotations
+
+ from collections.abc import Iterable
+ from dataclasses import dataclass, field
+
+ from .exceptions import CycleDetectedError
+
+
+ @dataclass(slots=True)
+ class Node:
+     name: str
+     upstream: set[str] = field(default_factory=set)
+     downstream: set[str] = field(default_factory=set)
+
+
+ class DAG:
+     def __init__(self) -> None:
+         self.nodes: dict[str, Node] = {}
+
+     def add_node(self, name: str) -> None:
+         if name not in self.nodes:
+             self.nodes[name] = Node(name=name)
+
+     def add_edge(self, upstream: str, downstream: str) -> None:
+         self.add_node(upstream)
+         self.add_node(downstream)
+         self.nodes[upstream].downstream.add(downstream)
+         self.nodes[downstream].upstream.add(upstream)
+
+     def topological_sort(self) -> list[str]:
+         # Kahn's algorithm with deterministic ordering
+         in_degree = {n: len(node.upstream) for n, node in self.nodes.items()}
+         ready = sorted([n for n, d in in_degree.items() if d == 0])
+         order: list[str] = []
+         while ready:
+             # pop the smallest name for stable results
+             current = ready.pop(0)
+             order.append(current)
+             for child in sorted(self.nodes[current].downstream):
+                 in_degree[child] -= 1
+                 if in_degree[child] == 0:
+                     # keep ready sorted
+                     from bisect import insort
+
+                     insort(ready, child)
+         if len(order) != len(self.nodes):
+             # Return nodes that still have in-degree > 0, deterministically
+             remaining = sorted([n for n, d in in_degree.items() if d > 0])
+             raise CycleDetectedError(remaining)
+         return order
+
+     def subgraph(self, names: Iterable[str]) -> DAG:
+         sg = DAG()
+         for name in names:
+             if name not in self.nodes:
+                 continue
+             sg.add_node(name)
+             for d in self.nodes[name].downstream:
+                 if d in names:
+                     sg.add_edge(name, d)
+         return sg
+
+     # Export utilities
+     def to_dot(self) -> str:
+         lines = ["digraph G {"]
+         for name, node in self.nodes.items():
+             if not node.downstream:
+                 lines.append(f' "{name}";')
+             for d in node.downstream:
+                 lines.append(f' "{name}" -> "{d}";')
+         lines.append("}")
+         return "\n".join(lines)
+
+     def to_dict(self) -> dict:
+         return {
+             name: {"upstream": sorted(n.upstream), "downstream": sorted(n.downstream)}
+             for name, n in self.nodes.items()
+         }
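The DAG class above is small enough to exercise directly. A minimal usage sketch, assuming the wheel is installed and importable as auto_workflow:

# Usage sketch for auto_workflow.dag (assumes the package is installed).
from auto_workflow.dag import DAG
from auto_workflow.exceptions import CycleDetectedError

dag = DAG()
dag.add_edge("extract", "transform")
dag.add_edge("transform", "load")

# Kahn's algorithm with lexicographic tie-breaking -> ['extract', 'transform', 'load']
print(dag.topological_sort())

# A back edge leaves every node with in-degree > 0, so topological_sort raises
# CycleDetectedError and reports the remaining nodes.
dag.add_edge("load", "extract")
try:
    dag.topological_sort()
except CycleDetectedError as exc:
    print(exc.cycle)  # ['extract', 'load', 'transform']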
auto_workflow/events.py ADDED
@@ -0,0 +1,22 @@
+ """Lightweight event bus."""
+
+ from __future__ import annotations
+
+ import logging
+ from collections.abc import Callable
+ from typing import Any
+
+ _subscribers: dict[str, list[Callable[[dict[str, Any]], None]]] = {}
+ logger = logging.getLogger("auto_workflow.events")
+
+
+ def subscribe(event: str, callback: Callable[[dict[str, Any]], None]) -> None:
+     _subscribers.setdefault(event, []).append(callback)
+
+
+ def emit(event: str, payload: dict[str, Any]) -> None:
+     for cb in _subscribers.get(event, []):
+         try:
+             cb(payload)
+         except Exception as e:  # pragma: no cover
+             logger.debug("event subscriber error", exc_info=e)
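A short sketch of the event bus, assuming this hunk ships as auto_workflow.events (the logger name above and the `from .events import emit` in flow.py point that way):

# Sketch: subscribe to a lifecycle event and emit one manually.
from auto_workflow.events import emit, subscribe

def on_flow_completed(payload: dict) -> None:
    print(f"flow {payload.get('flow')} finished with {payload.get('tasks')} task(s)")

subscribe("flow_completed", on_flow_completed)

# emit() invokes every subscriber; subscriber exceptions are swallowed and logged at DEBUG.
emit("flow_completed", {"flow": "demo", "run_id": "abc", "tasks": 3})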
auto_workflow/exceptions.py ADDED
@@ -0,0 +1,42 @@
+ """Domain-specific exceptions for auto_workflow."""
+
+ from __future__ import annotations
+
+
+ class AutoWorkflowError(Exception):
+     """Base exception."""
+
+
+ class CycleDetectedError(AutoWorkflowError):
+     def __init__(self, cycle: list[str]):
+         super().__init__(f"Cycle detected in DAG: {' -> '.join(cycle)}")
+         self.cycle = cycle
+
+
+ class TaskExecutionError(AutoWorkflowError):
+     def __init__(self, task_name: str, original: BaseException):
+         super().__init__(f"Task '{task_name}' failed: {original!r}")
+         self.task_name = task_name
+         self.original = original
+
+
+ class TimeoutError(TaskExecutionError):
+     pass
+
+
+ class RetryExhaustedError(TaskExecutionError):
+     pass
+
+
+ class InvalidGraphError(AutoWorkflowError):
+     pass
+
+
+ class AggregateTaskError(AutoWorkflowError):
+     """Raised when multiple tasks fail under the AGGREGATE failure policy."""
+
+     def __init__(self, errors: list[TaskExecutionError]):
+         self.errors = errors
+         summary = "; ".join(f"{e.task_name}: {e.original!r}" for e in errors[:5])
+         more = "" if len(errors) <= 5 else f" (+{len(errors) - 5} more)"
+         super().__init__(f"Multiple task failures: {summary}{more}")
auto_workflow/execution.py ADDED
@@ -0,0 +1,45 @@
+ """Execution helpers for process-based task offload."""
+
+ from __future__ import annotations
+
+ import atexit
+ from concurrent.futures import ProcessPoolExecutor
+ from contextlib import suppress
+ from typing import Any
+
+ import cloudpickle
+
+ from .config import load_config
+
+ _SHARED_POOL: ProcessPoolExecutor | None = None
+
+
+ def _shutdown_pool() -> None:
+     global _SHARED_POOL
+     if _SHARED_POOL is not None:
+         with suppress(Exception):
+             _SHARED_POOL.shutdown(wait=True, cancel_futures=True)
+         _SHARED_POOL = None
+
+
+ def get_process_pool() -> ProcessPoolExecutor:
+     global _SHARED_POOL
+     if _SHARED_POOL is None:
+         cfg = load_config()
+         max_workers = cfg.get("process_pool_max_workers")
+         if isinstance(max_workers, str):
+             try:
+                 max_workers = int(max_workers) if max_workers.isdigit() else None
+             except Exception:
+                 max_workers = None
+         if not isinstance(max_workers, int) or max_workers <= 0:
+             max_workers = None
+         _SHARED_POOL = ProcessPoolExecutor(max_workers=max_workers)
+         # Ensure clean shutdown on interpreter exit
+         atexit.register(_shutdown_pool)
+     return _SHARED_POOL
+
+
+ def run_pickled(payload: bytes) -> Any:  # executed in a worker process
+     fn, args, kwargs = cloudpickle.loads(payload)
+     return fn(*args, **kwargs)
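run_pickled expects a cloudpickle payload of (fn, args, kwargs). A hedged sketch of handing work to the shared pool; only get_process_pool, run_pickled, and the process_pool_max_workers config key come from the code above, the rest is illustrative:

# Illustrative: submit a cloudpickled (fn, args, kwargs) tuple to the shared pool.
import cloudpickle

from auto_workflow.execution import get_process_pool, run_pickled

def add(a: int, b: int) -> int:
    return a + b

if __name__ == "__main__":  # required for spawn-based ProcessPoolExecutor workers
    payload = cloudpickle.dumps((add, (2, 3), {}))
    future = get_process_pool().submit(run_pickled, payload)
    print(future.result())  # 5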
auto_workflow/fanout.py ADDED
@@ -0,0 +1,59 @@
+ """Fan-out helper utilities."""
+
+ from __future__ import annotations
+
+ from collections.abc import Iterable
+ from typing import Any
+
+ from .build import TaskInvocation, current_build_context
+
+
+ class DynamicFanOut(list):  # placeholder container recognized by the scheduler
+     def __init__(
+         self,
+         task_def,
+         source_invocation: TaskInvocation | DynamicFanOut,
+         max_concurrency: int | None,
+         ctx,
+     ):
+         super().__init__()
+         self._task_def = task_def
+         self._source = source_invocation
+         self._max_conc = max_concurrency
+         self._expanded = False
+         self._ctx = ctx
+
+     def expand(self, values: Iterable[Any]):
+         if self._expanded:
+             return
+         for v in values:
+             self.append(
+                 self._ctx.register(self._task_def.name, self._task_def.fn, (v,), {}, self._task_def)
+             )
+         self._expanded = True
+
+
+ def fan_out(task_def, iterable: Iterable[Any], *, max_concurrency: int | None = None) -> list[Any]:
+     """Create multiple task invocations from an iterable.
+
+     max_concurrency is reserved for future scheduling throttling; currently unused.
+     """
+     ctx = current_build_context()
+     out = []
+     if ctx is None:
+         # immediate execution path
+         return [task_def(item) for item in iterable]
+     from .fanout import DynamicFanOut  # self-import safe
+
+     if isinstance(iterable, TaskInvocation):  # dynamic runtime fan-out
+         df = DynamicFanOut(task_def, iterable, max_concurrency, ctx)
+         ctx.dynamic_fanouts.append(df)
+         return df
+     if isinstance(iterable, DynamicFanOut):  # nested dynamic fan-out
+         # Nested placeholder; track as a root only if the original source is a (first-level) TaskInvocation
+         df = DynamicFanOut(task_def, iterable, max_concurrency, ctx)
+         ctx.dynamic_fanouts.append(df)
+         return df
+     for item in iterable:
+         out.append(task_def(item))
+     return out
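Inside a flow, fan_out expects a @task definition and an active build context, neither of which appears in this diff, so the sketch below exercises only the immediate-execution fallback (that current_build_context() returns None outside a flow is an assumption implied by the branch above); a plain callable stands in for a task definition:

# Sketch of the immediate-execution fallback: outside a build context,
# fan_out calls the task once per item and returns the results.
from auto_workflow.fanout import fan_out

def square(x: int) -> int:
    return x * x

print(fan_out(square, [1, 2, 3]))  # [1, 4, 9]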
auto_workflow/flow.py ADDED
@@ -0,0 +1,165 @@
+ """Flow abstraction and decorator."""
+
+ from __future__ import annotations
+
+ import asyncio
+ import uuid
+ from collections.abc import Callable
+ from dataclasses import dataclass
+ from typing import Any
+
+ from .build import BuildContext, collect_invocations, replace_invocations
+ from .config import load_config
+ from .context import RunContext, set_context
+ from .events import emit
+ from .scheduler import FailurePolicy, execute_dag
+ from .tracing import get_tracer
+
+
+ @dataclass(slots=True)
+ class Flow:
+     name: str
+     build_fn: Callable[..., Any]
+
+     async def _run_internal(
+         self,
+         *args: Any,
+         params: dict[str, Any] | None = None,
+         failure_policy: str = FailurePolicy.FAIL_FAST,
+         max_concurrency: int | None = None,
+         **kwargs: Any,
+     ) -> Any:
+         cfg = load_config()
+         if max_concurrency is None:
+             mc = cfg.get("max_dynamic_tasks")
+             if isinstance(mc, str):
+                 if mc.isdigit():
+                     try:
+                         mc = int(mc)
+                     except ValueError:  # pragma: no cover
+                         mc = None
+                 else:
+                     mc = None  # non-numeric string -> ignore
+             max_concurrency = mc if isinstance(mc, int) and mc > 0 else None
+         ctx = RunContext(run_id=str(uuid.uuid4()), flow_name=self.name, params=params or {})
+         set_context(ctx)
+         tracer = get_tracer()
+         emit("flow_started", {"flow": self.name, "run_id": ctx.run_id})
+         async with tracer.span(f"flow:{self.name}", run_id=ctx.run_id):
+             with BuildContext() as bctx:
+                 structure = self.build_fn(*args, **kwargs)
+                 dynamic_roots = list(bctx.dynamic_fanouts)
+             invocations = collect_invocations(structure)
+             if not invocations:
+                 # Trivial case: no tasks were used; return the original structure
+                 emit("flow_completed", {"flow": self.name, "run_id": ctx.run_id, "tasks": 0})
+                 return structure
+             # Execute via scheduler
+             results = await execute_dag(
+                 invocations,
+                 failure_policy=failure_policy,
+                 max_concurrency=max_concurrency,
+                 dynamic_roots=dynamic_roots,
+             )  # type: ignore
+             emit("flow_completed", {"flow": self.name, "run_id": ctx.run_id, "tasks": len(invocations)})
+             return replace_invocations(structure, results)
+
+     def run(
+         self,
+         *args: Any,
+         params: dict[str, Any] | None = None,
+         failure_policy: str = FailurePolicy.FAIL_FAST,
+         max_concurrency: int | None = None,
+         **kwargs: Any,
+     ) -> Any:
+         return asyncio.run(
+             self._run_internal(
+                 *args,
+                 params=params,
+                 failure_policy=failure_policy,
+                 max_concurrency=max_concurrency,
+                 **kwargs,
+             )
+         )
+
+     def describe(self, *args: Any, **kwargs: Any) -> dict[str, Any]:
+         """Return a JSON-serializable representation of the DAG that would be built.
+         Does not execute any tasks.
+         """
+         with BuildContext():
+             structure = self.build_fn(*args, **kwargs)
+         invocations = collect_invocations(structure)
+         # Attach a synthetic upstream hint for nodes consuming dynamic placeholders
+         from .fanout import DynamicFanOut
+
+         for inv in invocations:
+             for arg in list(inv.args) + list(inv.kwargs.values()):
+                 if isinstance(arg, DynamicFanOut):
+                     # Depend on the source invocation so scheduling waits for expansion;
+                     # child edges are injected later
+                     src = arg._source
+                     if hasattr(src, "name"):
+                         inv.upstream.add(src.name)
+         nodes = []
+         for inv in invocations:
+             nodes.append(
+                 {
+                     "id": inv.name,
+                     "task": inv.task_name,
+                     "upstream": sorted(inv.upstream),
+                     "persist": getattr(inv.definition, "persist", False),
+                     "run_in": getattr(inv.definition, "run_in", "async"),
+                     "retries": getattr(inv.definition, "retries", 0),
+                 }
+             )
+         return {
+             "flow": self.name,
+             "nodes": nodes,
+             "count": len(nodes),
+         }
+
+     def export_dot(self, *args: Any, **kwargs: Any) -> str:
+         from .dag import DAG
+
+         with BuildContext():
+             structure = self.build_fn(*args, **kwargs)
+         invocations = collect_invocations(structure)
+         from .fanout import DynamicFanOut
+
+         # Ensure an upstream dependency on dynamic source(s)
+         for inv in invocations:
+             for arg in list(inv.args) + list(inv.kwargs.values()):
+                 if isinstance(arg, DynamicFanOut):
+                     src = arg._source
+                     if hasattr(src, "name"):
+                         inv.upstream.add(src.name)
+         dag = DAG()
+         for inv in invocations:
+             dag.add_node(inv.name)
+             for up in inv.upstream:
+                 dag.add_edge(up, inv.name)
+         return dag.to_dot()
+
+     def export_graph(self, *args: Any, **kwargs: Any) -> dict[str, Any]:
+         from .dag import DAG
+
+         with BuildContext():
+             structure = self.build_fn(*args, **kwargs)
+         invocations = collect_invocations(structure)
+         dag = DAG()
+         for inv in invocations:
+             dag.add_node(inv.name)
+             for up in inv.upstream:
+                 dag.add_edge(up, inv.name)
+         return dag.to_dict()
+
+
+ def flow(
+     _fn: Callable[..., Any] | None = None, *, name: str | None = None
+ ) -> Flow | Callable[[Callable[..., Any]], Flow]:
+     def wrap(fn: Callable[..., Any]) -> Flow:
+         return Flow(name=name or fn.__name__, build_fn=fn)
+
+     if _fn is not None:
+         return wrap(_fn)
+     return wrap
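The @task decorator is not part of this diff, so a hedged sketch of the flow decorator can only show the trivial path: when the built structure contains no task invocations, run() returns it unchanged and describe() reports zero nodes. It assumes the wheel and its sibling modules (.config, .context, .tracing, .scheduler, .build) are installed:

# Sketch: @flow wraps a build function in a Flow object.
from auto_workflow.flow import flow

@flow(name="greeting")
def build(who: str) -> dict:
    return {"message": f"hello {who}"}

# No tasks are invoked, so run() hits the trivial branch and returns the structure.
print(build.run("world"))       # {'message': 'hello world'}
print(build.describe("world"))  # {'flow': 'greeting', 'nodes': [], 'count': 0}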
@@ -0,0 +1,16 @@
+ """Lifecycle helpers for graceful shutdown of runtime components."""
+
+ from __future__ import annotations
+
+ from contextlib import suppress
+
+ from .execution import _shutdown_pool  # internal but intentional for lifecycle
+
+
+ def shutdown() -> None:
+     """Gracefully shut down background resources (process pool, etc.).
+
+     Safe to call multiple times.
+     """
+     with suppress(Exception):
+         _shutdown_pool()
@@ -0,0 +1,191 @@
+ """Structured logging for flows and tasks.
+
+ Provides:
+ - structured_logging_middleware: logs task completion (ok/err) with duration.
+ - register_structured_logging(): registers the middleware and event subscribers to log
+   flow start/completion and task start events.
+ """
+
+ from __future__ import annotations
+
+ import json
+ import logging
+ import time
+ from contextlib import suppress
+ from typing import Any
+
+ from .context import get_context
+ from .events import subscribe
+
+ logger = logging.getLogger("auto_workflow.tasks")
+
+
+ def _now_iso() -> str:
+     import datetime as _dt
+
+     return _dt.datetime.now(tz=_dt.UTC).isoformat()
+
+
+ async def structured_logging_middleware(nxt, task_def, args, kwargs):
+     ctx = None
+     with suppress(Exception):  # safe to call outside a flow
+         ctx = get_context()
+     start = time.time()
+     meta: dict[str, Any] = {
+         "task": task_def.name,
+         "run_id": getattr(ctx, "run_id", None),
+         "flow": getattr(ctx, "flow_name", None),
+         "ts": _now_iso(),
+     }
+     try:
+         result = await nxt()
+         duration = (time.time() - start) * 1000.0
+         meta["duration_ms"] = duration
+         logger.info(json.dumps({"event": "task_ok", **meta}))
+         return result
+     except Exception as e:
+         duration = (time.time() - start) * 1000.0
+         meta["duration_ms"] = duration
+         meta["error"] = repr(e)
+         logger.error(json.dumps({"event": "task_err", **meta}))
+         raise
+
+
+ # Event subscribers for flow/task lifecycle
+ def _on_flow_started(payload: dict[str, Any]) -> None:
+     logger.info(
+         json.dumps(
+             {
+                 "event": "flow_started",
+                 "flow": payload.get("flow"),
+                 "run_id": payload.get("run_id"),
+                 "ts": _now_iso(),
+             }
+         )
+     )
+
+
+ def _on_flow_completed(payload: dict[str, Any]) -> None:
+     logger.info(
+         json.dumps(
+             {
+                 "event": "flow_completed",
+                 "flow": payload.get("flow"),
+                 "run_id": payload.get("run_id"),
+                 "tasks": payload.get("tasks"),
+                 "ts": _now_iso(),
+             }
+         )
+     )
+
+
+ def _on_task_started(payload: dict[str, Any]) -> None:
+     logger.info(
+         json.dumps(
+             {
+                 "event": "task_started",
+                 "task": payload.get("task"),
+                 "node": payload.get("node"),
+                 "ts": _now_iso(),
+             }
+         )
+     )
+
+
+ _registered = False
+
+
+ def register_structured_logging() -> None:
+     """Register structured logging for flow and task lifecycle.
+
+     - Subscribes to flow_started, flow_completed, and task_started events.
+     - Registers the task middleware to log completion with duration.
+     """
+     global _registered
+     if _registered:
+         return
+     subscribe("flow_started", _on_flow_started)
+     subscribe("flow_completed", _on_flow_completed)
+     subscribe("task_started", _on_task_started)
+     # Defer import to avoid cycles
+     from .middleware import register as _register
+
+     _register(structured_logging_middleware)
+     _registered = True
+
+
+ def enable_default_logging(level: str = "INFO") -> None:
+     """Attach a simple stdout handler to the auto_workflow logger if none is present.
+
+     This makes structured JSON lines visible when running scripts.
+     """
+     if not logger.handlers:
+         h = logging.StreamHandler()
+         h.setFormatter(logging.Formatter("%(message)s"))
+         # Mark it so we can replace it with the pretty handler if requested
+         h._aw_default = True
+         logger.addHandler(h)
+     try:
+         logger.setLevel(getattr(logging, level.upper(), logging.INFO))
+     except Exception:
+         logger.setLevel(logging.INFO)
+     # Avoid duplicate propagation to root
+     logger.propagate = False
+
+
+ class StructuredPrettyFormatter(logging.Formatter):
+     def __init__(self, datefmt: str | None = None) -> None:
+         super().__init__(datefmt=datefmt)
+
+     def format(self, record: logging.LogRecord) -> str:
+         # Try to parse the JSON message emitted by the structured logger
+         try:
+             data = json.loads(record.getMessage())
+         except Exception:
+             return super().format(record)
+         ts = self.formatTime(record, self.datefmt)
+         lvl = record.levelname
+         parts = [ts, lvl]
+         ev = data.get("event")
+         if ev:
+             parts.append(ev)
+         # Common fields
+         kvs = []
+         for k in ("flow", "run_id", "task", "node"):
+             v = data.get(k)
+             if v is not None:
+                 kvs.append(f"{k}={v}")
+         if "duration_ms" in data:
+             try:
+                 kvs.append(f"duration={float(data['duration_ms']):.1f}ms")
+             except Exception:
+                 kvs.append(f"duration={data['duration_ms']}ms")
+         if "error" in data:
+             kvs.append(f"error={data['error']}")
+         if kvs:
+             parts.append(" ".join(kvs))
+         return " | ".join(parts)
+
+
+ def enable_pretty_logging(level: str = "INFO") -> None:
+     """Replace the default JSON-line handler with a human-friendly formatter.
+
+     Safe to call multiple times.
+     """
+     # Remove our default handler (and any previously added pretty handler) if present
+     new_handlers = []
+     for h in logger.handlers:
+         if getattr(h, "_aw_default", False) or getattr(h, "_aw_pretty", False):
+             continue
+         new_handlers.append(h)
+     logger.handlers = new_handlers
+     # Add the pretty handler
+     h = logging.StreamHandler()
+     h.setFormatter(StructuredPrettyFormatter(datefmt="%Y-%m-%d %H:%M:%S%z"))
+     h._aw_pretty = True
+     logger.addHandler(h)
+     try:
+         logger.setLevel(getattr(logging, level.upper(), logging.INFO))
+     except Exception:
+         logger.setLevel(logging.INFO)
+     logger.propagate = False
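The filename of this hunk is not shown, so the import path below is an assumption; adjust it to wherever these helpers live in the installed wheel. The sketch wires up the subscribers and middleware, then emits an event to show the pretty formatter in action:

# Sketch only; the module path is hypothetical (this hunk does not name its file).
from auto_workflow.logging_utils import enable_pretty_logging, register_structured_logging  # hypothetical path

from auto_workflow.events import emit

register_structured_logging()  # subscribes to flow/task events and installs the task middleware
enable_pretty_logging("INFO")  # renders the JSON lines as "ts | LEVEL | event | k=v ..."

# The flow_started subscriber logs a JSON line, which StructuredPrettyFormatter reformats.
emit("flow_started", {"flow": "demo", "run_id": "run-1"})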
@@ -0,0 +1,35 @@
+ """Pluggable metrics provider abstraction."""
+
+ from __future__ import annotations
+
+ from typing import Any, Protocol
+
+
+ class MetricsProvider(Protocol):  # pragma: no cover - interface
+     def inc(self, name: str, value: float = 1.0, **labels: Any) -> None: ...
+     def observe(self, name: str, value: float, **labels: Any) -> None: ...
+
+
+ class InMemoryMetrics(MetricsProvider):
+     def __init__(self) -> None:
+         self.counters: dict[str, float] = {}
+         self.histograms: dict[str, list[float]] = {}
+
+     def inc(self, name: str, value: float = 1.0, **labels: Any) -> None:
+         key = name
+         self.counters[key] = self.counters.get(key, 0.0) + value
+
+     def observe(self, name: str, value: float, **labels: Any) -> None:
+         self.histograms.setdefault(name, []).append(value)
+
+
+ _provider: MetricsProvider = InMemoryMetrics()
+
+
+ def set_metrics_provider(p: MetricsProvider) -> None:
+     global _provider
+     _provider = p
+
+
+ def get_metrics_provider() -> MetricsProvider:
+     return _provider
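A sketch of swapping in a custom provider via the protocol above; the import path is an assumption since this hunk's filename is not shown:

# Sketch: replace the default InMemoryMetrics with a provider that prints each data point.
from typing import Any

from auto_workflow.metrics import get_metrics_provider, set_metrics_provider  # hypothetical path

class PrintMetrics:
    """Structurally satisfies MetricsProvider."""

    def inc(self, name: str, value: float = 1.0, **labels: Any) -> None:
        print(f"counter {name} += {value} {labels}")

    def observe(self, name: str, value: float, **labels: Any) -> None:
        print(f"histogram {name} <- {value} {labels}")

set_metrics_provider(PrintMetrics())
get_metrics_provider().inc("tasks_completed", flow="demo")
get_metrics_provider().observe("task_duration_ms", 12.5, task="extract")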