auto-workflow 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- assets/logo.svg +6524 -0
- auto_workflow/__init__.py +40 -0
- auto_workflow/__main__.py +9 -0
- auto_workflow/artifacts.py +119 -0
- auto_workflow/build.py +158 -0
- auto_workflow/cache.py +111 -0
- auto_workflow/cli.py +78 -0
- auto_workflow/config.py +76 -0
- auto_workflow/context.py +32 -0
- auto_workflow/dag.py +80 -0
- auto_workflow/events.py +22 -0
- auto_workflow/exceptions.py +42 -0
- auto_workflow/execution.py +45 -0
- auto_workflow/fanout.py +59 -0
- auto_workflow/flow.py +165 -0
- auto_workflow/lifecycle.py +16 -0
- auto_workflow/logging_middleware.py +191 -0
- auto_workflow/metrics_provider.py +35 -0
- auto_workflow/middleware.py +59 -0
- auto_workflow/py.typed +0 -0
- auto_workflow/scheduler.py +362 -0
- auto_workflow/secrets.py +45 -0
- auto_workflow/task.py +158 -0
- auto_workflow/tracing.py +29 -0
- auto_workflow/types.py +27 -0
- auto_workflow/utils.py +39 -0
- auto_workflow-0.1.0.dist-info/LICENSE +674 -0
- auto_workflow-0.1.0.dist-info/METADATA +423 -0
- auto_workflow-0.1.0.dist-info/RECORD +31 -0
- auto_workflow-0.1.0.dist-info/WHEEL +4 -0
- auto_workflow-0.1.0.dist-info/entry_points.txt +3 -0
auto_workflow/dag.py
ADDED
@@ -0,0 +1,80 @@
|
|
1
|
+
"""Internal DAG representation and cycle detection."""
|
2
|
+
|
3
|
+
from __future__ import annotations
|
4
|
+
|
5
|
+
from collections.abc import Iterable
|
6
|
+
from dataclasses import dataclass, field
|
7
|
+
|
8
|
+
from .exceptions import CycleDetectedError
|
9
|
+
|
10
|
+
|
11
|
+
@dataclass(slots=True)
|
12
|
+
class Node:
|
13
|
+
name: str
|
14
|
+
upstream: set[str] = field(default_factory=set)
|
15
|
+
downstream: set[str] = field(default_factory=set)
|
16
|
+
|
17
|
+
|
18
|
+
class DAG:
|
19
|
+
def __init__(self) -> None:
|
20
|
+
self.nodes: dict[str, Node] = {}
|
21
|
+
|
22
|
+
def add_node(self, name: str) -> None:
|
23
|
+
if name not in self.nodes:
|
24
|
+
self.nodes[name] = Node(name=name)
|
25
|
+
|
26
|
+
def add_edge(self, upstream: str, downstream: str) -> None:
|
27
|
+
self.add_node(upstream)
|
28
|
+
self.add_node(downstream)
|
29
|
+
self.nodes[upstream].downstream.add(downstream)
|
30
|
+
self.nodes[downstream].upstream.add(upstream)
|
31
|
+
|
32
|
+
def topological_sort(self) -> list[str]:
|
33
|
+
# Kahn's algorithm with deterministic ordering
|
34
|
+
in_degree = {n: len(node.upstream) for n, node in self.nodes.items()}
|
35
|
+
ready = sorted([n for n, d in in_degree.items() if d == 0])
|
36
|
+
order: list[str] = []
|
37
|
+
while ready:
|
38
|
+
# pop the smallest name for stable results
|
39
|
+
current = ready.pop(0)
|
40
|
+
order.append(current)
|
41
|
+
for child in sorted(self.nodes[current].downstream):
|
42
|
+
in_degree[child] -= 1
|
43
|
+
if in_degree[child] == 0:
|
44
|
+
# keep ready sorted
|
45
|
+
from bisect import insort
|
46
|
+
|
47
|
+
insort(ready, child)
|
48
|
+
if len(order) != len(self.nodes):
|
49
|
+
# Return nodes that still have in-degree > 0 deterministically
|
50
|
+
remaining = sorted([n for n, d in in_degree.items() if d > 0])
|
51
|
+
raise CycleDetectedError(remaining)
|
52
|
+
return order
|
53
|
+
|
54
|
+
def subgraph(self, names: Iterable[str]) -> DAG:
|
55
|
+
sg = DAG()
|
56
|
+
for name in names:
|
57
|
+
if name not in self.nodes:
|
58
|
+
continue
|
59
|
+
sg.add_node(name)
|
60
|
+
for d in self.nodes[name].downstream:
|
61
|
+
if d in names:
|
62
|
+
sg.add_edge(name, d)
|
63
|
+
return sg
|
64
|
+
|
65
|
+
# Export utilities
|
66
|
+
def to_dot(self) -> str:
|
67
|
+
lines = ["digraph G {"]
|
68
|
+
for name, node in self.nodes.items():
|
69
|
+
if not node.downstream:
|
70
|
+
lines.append(f' "{name}";')
|
71
|
+
for d in node.downstream:
|
72
|
+
lines.append(f' "{name}" -> "{d}";')
|
73
|
+
lines.append("}")
|
74
|
+
return "\n".join(lines)
|
75
|
+
|
76
|
+
def to_dict(self) -> dict:
|
77
|
+
return {
|
78
|
+
name: {"upstream": sorted(n.upstream), "downstream": sorted(n.downstream)}
|
79
|
+
for name, n in self.nodes.items()
|
80
|
+
}
|
auto_workflow/events.py
ADDED
@@ -0,0 +1,22 @@
|
|
1
|
+
"""Lightweight event bus."""
|
2
|
+
|
3
|
+
from __future__ import annotations
|
4
|
+
|
5
|
+
import logging
|
6
|
+
from collections.abc import Callable
|
7
|
+
from typing import Any
|
8
|
+
|
9
|
+
# Event name -> callbacks, in subscription order.
_subscribers: dict[str, list[Callable[[dict[str, Any]], None]]] = {}
logger = logging.getLogger("auto_workflow.events")


def subscribe(event: str, callback: Callable[[dict[str, Any]], None]) -> None:
    """Register ``callback`` to be called with the payload of every later ``emit(event, ...)``."""
    _subscribers.setdefault(event, []).append(callback)


def emit(event: str, payload: dict[str, Any]) -> None:
    """Synchronously deliver ``payload`` to each subscriber of ``event``.

    Exceptions raised by a subscriber are logged at DEBUG level and swallowed,
    so one faulty subscriber cannot break the emitter or other subscribers.
    """
    # Iterate a snapshot: a callback that subscribes (to this same event)
    # while being notified must not be invoked during this emit.
    for cb in tuple(_subscribers.get(event, ())):
        try:
            cb(payload)
        except Exception as e:  # pragma: no cover
            logger.debug("event subscriber error", exc_info=e)
|
@@ -0,0 +1,42 @@
|
|
1
|
+
"""Domain-specific exceptions for auto_workflow."""
|
2
|
+
|
3
|
+
from __future__ import annotations
|
4
|
+
|
5
|
+
|
6
|
+
class AutoWorkflowError(Exception):
    """Base class for every exception raised by auto_workflow."""


class CycleDetectedError(AutoWorkflowError):
    """Raised when a dependency graph cannot be ordered because of a cycle.

    Attributes:
        cycle: node names supplied by the raiser (joined with `` -> `` in the message).
    """

    def __init__(self, cycle: list[str]):
        super().__init__(f"Cycle detected in DAG: {' -> '.join(cycle)}")
        self.cycle = cycle

    def __reduce__(self):
        # ``self.args`` holds the formatted message, not the constructor
        # argument; default exception pickling would call ``cls(message)`` and
        # garble ``cycle``. Rebuild from the real argument so the error
        # survives pickling (e.g. across a process-pool boundary).
        return (type(self), (self.cycle,), self.__dict__)


class TaskExecutionError(AutoWorkflowError):
    """A task raised an exception.

    Attributes:
        task_name: name of the failing task.
        original: the exception raised by the task.
    """

    def __init__(self, task_name: str, original: BaseException):
        super().__init__(f"Task '{task_name}' failed: {original!r}")
        self.task_name = task_name
        self.original = original

    def __reduce__(self):
        # Without this, unpickling calls ``cls(message)`` and fails with a
        # missing-argument TypeError. Subclasses share the same constructor.
        return (type(self), (self.task_name, self.original), self.__dict__)


# NOTE: intentionally named like the built-in; importing it by name shadows
# ``builtins.TimeoutError`` in the importing module.
class TimeoutError(TaskExecutionError):
    """A task exceeded its allotted time."""


class RetryExhaustedError(TaskExecutionError):
    """A task kept failing after all of its retries were used."""


class InvalidGraphError(AutoWorkflowError):
    """The task graph is malformed."""


class AggregateTaskError(AutoWorkflowError):
    """Raised when multiple tasks fail under AGGREGATE failure policy.

    Attributes:
        errors: the individual task failures; at most five are summarized in
            the message, with a count of the rest.
    """

    def __init__(self, errors: list[TaskExecutionError]):
        self.errors = errors
        summary = "; ".join(f"{e.task_name}: {e.original!r}" for e in errors[:5])
        more = "" if len(errors) <= 5 else f" (+{len(errors) - 5} more)"
        super().__init__(f"Multiple task failures: {summary}{more}")

    def __reduce__(self):
        # Rebuild from the list of errors, not the formatted message.
        return (type(self), (self.errors,), self.__dict__)
|
@@ -0,0 +1,45 @@
|
|
1
|
+
"""Execution helpers for process-based task offload."""
|
2
|
+
|
3
|
+
from __future__ import annotations
|
4
|
+
|
5
|
+
import atexit
|
6
|
+
from concurrent.futures import ProcessPoolExecutor
|
7
|
+
from contextlib import suppress
|
8
|
+
from typing import Any
|
9
|
+
|
10
|
+
import cloudpickle
|
11
|
+
|
12
|
+
from .config import load_config
|
13
|
+
|
14
|
+
_SHARED_POOL: ProcessPoolExecutor | None = None
# Whether _shutdown_pool has already been registered with atexit.
_ATEXIT_REGISTERED = False


def _shutdown_pool() -> None:
    """Shut down the shared process pool, if any, and forget it.

    Waits for running work and cancels pending futures. Errors raised by
    ``shutdown`` are suppressed. Safe to call repeatedly.
    """
    global _SHARED_POOL
    if _SHARED_POOL is not None:
        with suppress(Exception):
            _SHARED_POOL.shutdown(wait=True, cancel_futures=True)
        _SHARED_POOL = None


def _coerce_max_workers(value: Any) -> int | None:
    """Normalize a configured worker count to a positive ``int`` or ``None``.

    Integer strings (surrounding whitespace allowed) are parsed. Anything that
    is not a positive integer, including booleans, yields ``None`` so that
    ``ProcessPoolExecutor`` picks its default.
    """
    if isinstance(value, str):
        try:
            value = int(value.strip())
        except ValueError:
            return None
    # bool is a subclass of int; ``True`` is not a meaningful worker count.
    if isinstance(value, bool) or not isinstance(value, int) or value <= 0:
        return None
    return value


def get_process_pool() -> ProcessPoolExecutor:
    """Return the lazily created process-wide ``ProcessPoolExecutor``.

    The worker count comes from the ``process_pool_max_workers`` config key
    (see ``_coerce_max_workers``).
    """
    global _SHARED_POOL, _ATEXIT_REGISTERED
    if _SHARED_POOL is None:
        cfg = load_config()
        max_workers = _coerce_max_workers(cfg.get("process_pool_max_workers"))
        _SHARED_POOL = ProcessPoolExecutor(max_workers=max_workers)
        if not _ATEXIT_REGISTERED:
            # Ensure clean shutdown on interpreter exit. Register only once,
            # even if the pool is recreated after _shutdown_pool().
            atexit.register(_shutdown_pool)
            _ATEXIT_REGISTERED = True
    return _SHARED_POOL


def run_pickled(payload: bytes) -> Any:  # executed in worker process
    """Unpickle ``(fn, args, kwargs)`` from ``payload`` and return ``fn(*args, **kwargs)``.

    SECURITY: ``cloudpickle.loads`` can execute arbitrary code; ``payload`` must
    only ever come from this package's own parent process, never from
    untrusted input.
    """
    fn, args, kwargs = cloudpickle.loads(payload)
    return fn(*args, **kwargs)
|
auto_workflow/fanout.py
ADDED
@@ -0,0 +1,59 @@
|
|
1
|
+
"""Fan-out helper utilities."""
|
2
|
+
|
3
|
+
from __future__ import annotations
|
4
|
+
|
5
|
+
from collections.abc import Iterable
|
6
|
+
from typing import Any
|
7
|
+
|
8
|
+
from .build import TaskInvocation, current_build_context
|
9
|
+
|
10
|
+
|
11
|
+
class DynamicFanOut(list):  # placeholder container recognized by scheduler
    """List placeholder for a fan-out whose input values are known only at run time.

    It starts empty. ``expand`` fills it with one invocation per value, each
    obtained from ``ctx.register``.
    """

    def __init__(
        self,
        task_def,
        source_invocation: TaskInvocation | DynamicFanOut,
        max_concurrency: int | None,
        ctx,
    ):
        super().__init__()
        self._ctx = ctx
        self._task_def = task_def
        self._source = source_invocation
        self._max_conc = max_concurrency
        self._expanded = False

    def expand(self, values: Iterable[Any]):
        """Register one invocation of the task per item in ``values``.

        Only the first call has an effect; later calls are ignored.
        """
        if not self._expanded:
            self.extend(
                self._ctx.register(
                    self._task_def.name, self._task_def.fn, (value,), {}, self._task_def
                )
                for value in values
            )
            self._expanded = True
|
34
|
+
|
35
|
+
|
36
|
+
def fan_out(task_def, iterable: Iterable[Any], *, max_concurrency: int | None = None) -> list[Any]:
    """Create one task invocation per item of ``iterable``.

    * With no active build context, ``task_def`` is called directly on every
      item and the list of its return values is returned.
    * Inside a build, if ``iterable`` is a ``TaskInvocation`` or another
      ``DynamicFanOut``, its items are not known yet. A ``DynamicFanOut``
      placeholder is created, appended to ``ctx.dynamic_fanouts`` and returned.
    * Otherwise ``task_def`` is called once per item and the collected return
      values are returned.

    ``max_concurrency`` is stored on dynamic placeholders but is not enforced
    here (reserved for future scheduling throttling).
    """
    ctx = current_build_context()
    if ctx is None:
        # immediate execution path
        return [task_def(item) for item in iterable]
    # DynamicFanOut is defined in this module, so no import is needed. Both
    # runtime-valued sources (first-level and nested) are handled alike.
    if isinstance(iterable, (TaskInvocation, DynamicFanOut)):
        df = DynamicFanOut(task_def, iterable, max_concurrency, ctx)
        ctx.dynamic_fanouts.append(df)
        return df
    return [task_def(item) for item in iterable]
|
auto_workflow/flow.py
ADDED
@@ -0,0 +1,165 @@
|
|
1
|
+
"""Flow abstraction and decorator."""
|
2
|
+
|
3
|
+
from __future__ import annotations
|
4
|
+
|
5
|
+
import asyncio
|
6
|
+
import uuid
|
7
|
+
from collections.abc import Callable
|
8
|
+
from dataclasses import dataclass
|
9
|
+
from typing import Any
|
10
|
+
|
11
|
+
from .build import BuildContext, collect_invocations, replace_invocations
|
12
|
+
from .config import load_config
|
13
|
+
from .context import RunContext, set_context
|
14
|
+
from .events import emit
|
15
|
+
from .scheduler import FailurePolicy, execute_dag
|
16
|
+
from .tracing import get_tracer
|
17
|
+
|
18
|
+
|
19
|
+
@dataclass(slots=True)
class Flow:
    """A named workflow defined by a build function.

    ``build_fn`` is called inside a ``BuildContext``. The task invocations
    found in its return value (via ``collect_invocations``) are either executed
    by the scheduler (``run``) or only inspected (``describe``, ``export_dot``,
    ``export_graph``).
    """

    name: str
    build_fn: Callable[..., Any]

    async def _run_internal(
        self,
        *args: Any,
        params: dict[str, Any] | None = None,
        failure_policy: str = FailurePolicy.FAIL_FAST,
        max_concurrency: int | None = None,
        **kwargs: Any,
    ) -> Any:
        """Build the flow, execute its invocations and return the resolved structure.

        ``args``/``kwargs`` are forwarded to ``build_fn``. When
        ``max_concurrency`` is ``None``, the ``max_dynamic_tasks`` config
        value is used if it is a positive int or an all-digit string;
        otherwise ``None`` is passed to the scheduler.
        """
        cfg = load_config()
        if max_concurrency is None:
            mc = cfg.get("max_dynamic_tasks")
            if isinstance(mc, str):
                try:
                    mc = int(mc) if mc.isdigit() else None  # non-numeric string -> ignore
                except ValueError:  # pragma: no cover - e.g. non-ASCII digit characters
                    mc = None
            max_concurrency = mc if isinstance(mc, int) and mc > 0 else None
        ctx = RunContext(run_id=str(uuid.uuid4()), flow_name=self.name, params=params or {})
        set_context(ctx)
        tracer = get_tracer()
        emit("flow_started", {"flow": self.name, "run_id": ctx.run_id})
        async with tracer.span(f"flow:{self.name}", run_id=ctx.run_id):
            with BuildContext() as bctx:
                structure = self.build_fn(*args, **kwargs)
                dynamic_roots = list(bctx.dynamic_fanouts)
            invocations = collect_invocations(structure)
            if not invocations:
                # trivial, return original structure (no tasks used)
                emit("flow_completed", {"flow": self.name, "run_id": ctx.run_id, "tasks": 0})
                return structure
            # Execute via scheduler
            results = await execute_dag(
                invocations,
                failure_policy=failure_policy,
                max_concurrency=max_concurrency,
                dynamic_roots=dynamic_roots,
            )  # type: ignore
            emit("flow_completed", {"flow": self.name, "run_id": ctx.run_id, "tasks": len(invocations)})
            return replace_invocations(structure, results)

    def run(
        self,
        *args: Any,
        params: dict[str, Any] | None = None,
        failure_policy: str = FailurePolicy.FAIL_FAST,
        max_concurrency: int | None = None,
        **kwargs: Any,
    ) -> Any:
        """Synchronously run the flow in a fresh event loop (``asyncio.run``)."""
        return asyncio.run(
            self._run_internal(
                *args,
                params=params,
                failure_policy=failure_policy,
                max_concurrency=max_concurrency,
                **kwargs,
            )
        )

    def _build_invocations(self, args: tuple[Any, ...], kwargs: dict[str, Any]) -> list[Any]:
        """Build the flow without executing it and return its task invocations.

        Invocations that consume a ``DynamicFanOut`` placeholder get an extra
        upstream edge to the placeholder's source. This way every static view
        of the graph shows that they wait for the fan-out to be expanded.
        """
        from .fanout import DynamicFanOut

        with BuildContext():
            structure = self.build_fn(*args, **kwargs)
        invocations = collect_invocations(structure)
        for inv in invocations:
            for arg in (*inv.args, *inv.kwargs.values()):
                if isinstance(arg, DynamicFanOut):
                    src = arg._source
                    if hasattr(src, "name"):
                        inv.upstream.add(src.name)
        return invocations

    @staticmethod
    def _to_dag(invocations: list[Any]) -> Any:
        """Convert invocations (name + upstream names) into a ``DAG``."""
        from .dag import DAG

        dag = DAG()
        for inv in invocations:
            dag.add_node(inv.name)
            for up in inv.upstream:
                dag.add_edge(up, inv.name)
        return dag

    def describe(self, *args: Any, **kwargs: Any) -> dict[str, Any]:
        """Return a JSON-serializable representation of the DAG that would be built.

        Does not execute any tasks.
        """
        invocations = self._build_invocations(args, kwargs)
        nodes = [
            {
                "id": inv.name,
                "task": inv.task_name,
                "upstream": sorted(inv.upstream),
                "persist": getattr(inv.definition, "persist", False),
                "run_in": getattr(inv.definition, "run_in", "async"),
                "retries": getattr(inv.definition, "retries", 0),
            }
            for inv in invocations
        ]
        return {
            "flow": self.name,
            "nodes": nodes,
            "count": len(nodes),
        }

    def export_dot(self, *args: Any, **kwargs: Any) -> str:
        """Return the flow's task graph in Graphviz DOT syntax (no tasks executed)."""
        return self._to_dag(self._build_invocations(args, kwargs)).to_dot()

    def export_graph(self, *args: Any, **kwargs: Any) -> dict[str, Any]:
        """Return the flow's task graph as an adjacency dict (no tasks executed).

        Includes the same dynamic fan-out upstream edges as ``describe`` and
        ``export_dot``.
        """
        return self._to_dag(self._build_invocations(args, kwargs)).to_dict()
|
155
|
+
|
156
|
+
|
157
|
+
def flow(
    _fn: Callable[..., Any] | None = None, *, name: str | None = None
) -> Flow | Callable[[Callable[..., Any]], Flow]:
    """Turn a build function into a :class:`Flow`.

    Usable bare (``@flow``), which returns a ``Flow`` directly, or with
    arguments (``@flow(name="etl")``), which returns a decorator. When
    ``name`` is ``None`` or empty, the function's ``__name__`` is used.
    """

    def wrap(fn: Callable[..., Any]) -> Flow:
        return Flow(name=name or fn.__name__, build_fn=fn)

    if _fn is not None:
        return wrap(_fn)
    return wrap
|
@@ -0,0 +1,16 @@
|
|
1
|
+
"""Lifecycle helpers for graceful shutdown of runtime components."""
|
2
|
+
|
3
|
+
from __future__ import annotations
|
4
|
+
|
5
|
+
from contextlib import suppress
|
6
|
+
|
7
|
+
from .execution import _shutdown_pool # internal but intentional for lifecycle
|
8
|
+
|
9
|
+
|
10
|
+
def shutdown() -> None:
    """Release background runtime resources (currently the shared process pool).

    Idempotent: repeated calls are harmless. Any exception raised during
    teardown is swallowed.
    """
    try:
        _shutdown_pool()
    except Exception:
        # Best-effort cleanup: never let teardown errors reach the caller.
        pass
|
@@ -0,0 +1,191 @@
|
|
1
|
+
"""Structured logging for flows and tasks.
|
2
|
+
|
3
|
+
Provides:
|
4
|
+
- structured_logging_middleware: logs task completion (ok/err) with duration.
|
5
|
+
- register_structured_logging(): registers middleware and event subscribers to log
|
6
|
+
flow start/completion and task start events.
|
7
|
+
"""
|
8
|
+
|
9
|
+
from __future__ import annotations
|
10
|
+
|
11
|
+
import json
|
12
|
+
import logging
|
13
|
+
import time
|
14
|
+
from contextlib import suppress
|
15
|
+
from typing import Any
|
16
|
+
|
17
|
+
from .context import get_context
|
18
|
+
from .events import subscribe
|
19
|
+
|
20
|
+
logger = logging.getLogger("auto_workflow.tasks")


def _now_iso() -> str:
    """Return the current UTC time as an ISO-8601 string (``+00:00`` offset)."""
    import datetime as _dt

    # ``timezone.utc`` rather than the 3.11+ ``datetime.UTC`` alias, for
    # compatibility with older interpreters.
    return _dt.datetime.now(tz=_dt.timezone.utc).isoformat()


async def structured_logging_middleware(nxt, task_def, args, kwargs):
    """Task middleware that logs one JSON line when a task finishes.

    On success it logs ``task_ok`` at INFO level and returns the task result.
    On failure it logs ``task_err`` at ERROR level, including ``repr`` of the
    exception, then re-raises. Both lines carry task name, run id, flow name,
    a UTC start timestamp and ``duration_ms``.
    """
    ctx = None
    with suppress(Exception):  # outside flow safe
        ctx = get_context()
    meta: dict[str, Any] = {
        "task": task_def.name,
        "run_id": getattr(ctx, "run_id", None),
        "flow": getattr(ctx, "flow_name", None),
        "ts": _now_iso(),
    }
    # perf_counter is monotonic, so durations are immune to wall-clock jumps.
    start = time.perf_counter()
    try:
        result = await nxt()
    except Exception as e:
        meta["duration_ms"] = (time.perf_counter() - start) * 1000.0
        meta["error"] = repr(e)
        logger.error(json.dumps({"event": "task_err", **meta}))
        raise
    # Logged outside the try so a logging failure is never reported as a task error.
    meta["duration_ms"] = (time.perf_counter() - start) * 1000.0
    logger.info(json.dumps({"event": "task_ok", **meta}))
    return result
|
52
|
+
|
53
|
+
|
54
|
+
# Event subscribers for flow/task lifecycle
|
55
|
+
def _log_event(event: str, payload: dict[str, Any], *keys: str) -> None:
    """Log one INFO JSON line: event name, the selected payload fields, then a timestamp."""
    entry: dict[str, Any] = {"event": event}
    for key in keys:
        entry[key] = payload.get(key)
    entry["ts"] = _now_iso()
    logger.info(json.dumps(entry))


def _on_flow_started(payload: dict[str, Any]) -> None:
    """Subscriber for ``flow_started``: logs flow name and run id."""
    _log_event("flow_started", payload, "flow", "run_id")


def _on_flow_completed(payload: dict[str, Any]) -> None:
    """Subscriber for ``flow_completed``: logs flow name, run id and task count."""
    _log_event("flow_completed", payload, "flow", "run_id", "tasks")


def _on_task_started(payload: dict[str, Any]) -> None:
    """Subscriber for ``task_started``: logs task and node names."""
    _log_event("task_started", payload, "task", "node")
|
93
|
+
|
94
|
+
|
95
|
+
# Guards against double registration (which would duplicate every log line).
_registered = False


def register_structured_logging() -> None:
    """Turn on structured JSON logging for flow and task lifecycle.

    Subscribes the flow_started, flow_completed and task_started loggers to
    their events and installs the task-completion middleware. Subsequent
    calls are no-ops.
    """
    global _registered
    if not _registered:
        for event_name, handler in (
            ("flow_started", _on_flow_started),
            ("flow_completed", _on_flow_completed),
            ("task_started", _on_task_started),
        ):
            subscribe(event_name, handler)
        # Imported lazily to avoid an import cycle with the middleware module.
        from .middleware import register as _register

        _register(structured_logging_middleware)
        _registered = True
|
115
|
+
|
116
|
+
|
117
|
+
def enable_default_logging(level: str = "INFO") -> None:
    """Print this module's JSON log lines via a plain ``StreamHandler``.

    The handler is added only when the logger has no handlers yet. ``level``
    is a logging level name (case-insensitive); names that are unknown or
    invalid fall back to INFO. Propagation to ancestor loggers is disabled
    so each line is emitted once.
    """
    if not logger.handlers:
        handler = logging.StreamHandler()
        handler.setFormatter(logging.Formatter("%(message)s"))
        # Tag it so enable_pretty_logging() can find and replace it.
        handler._aw_default = True
        logger.addHandler(handler)
    try:
        chosen = getattr(logging, level.upper(), logging.INFO)
        logger.setLevel(chosen)
    except Exception:
        logger.setLevel(logging.INFO)
    logger.propagate = False
|
134
|
+
|
135
|
+
|
136
|
+
class StructuredPrettyFormatter(logging.Formatter):
    """Render this module's JSON log lines as ``time | LEVEL | event | k=v ...``.

    Messages that are not a JSON object fall back to the standard
    ``logging.Formatter`` output.
    """

    def __init__(self, datefmt: str | None = None) -> None:
        super().__init__(datefmt=datefmt)

    def format(self, record: logging.LogRecord) -> str:
        # Try to parse JSON message emitted by structured logger
        try:
            data = json.loads(record.getMessage())
        except Exception:
            return super().format(record)
        # Valid JSON that is not an object (e.g. "42", "null", a list) has no
        # fields to show; without this guard ``data.get`` would raise.
        if not isinstance(data, dict):
            return super().format(record)
        parts = [self.formatTime(record, self.datefmt), record.levelname]
        ev = data.get("event")
        if ev:
            parts.append(ev)
        # Common fields, in a fixed order; missing/None values are omitted.
        kvs = [f"{k}={data[k]}" for k in ("flow", "run_id", "task", "node") if data.get(k) is not None]
        if "duration_ms" in data:
            try:
                kvs.append(f"duration={float(data['duration_ms']):.1f}ms")
            except Exception:
                kvs.append(f"duration={data['duration_ms']}ms")
        if "error" in data:
            kvs.append(f"error={data['error']}")
        if kvs:
            parts.append(" ".join(kvs))
        return " | ".join(parts)
|
168
|
+
|
169
|
+
|
170
|
+
def enable_pretty_logging(level: str = "INFO") -> None:
    """Replace the JSON-line output with a human-friendly formatter.

    Removes handlers this module installed before (the default JSON handler
    and any earlier pretty handler), then attaches a single ``StreamHandler``
    using :class:`StructuredPrettyFormatter`. Handlers added by other code are
    kept. Safe to call multiple times: repeated calls do not stack handlers.
    ``level`` is a level name; unknown or invalid names fall back to INFO.
    """
    # Drop our own handlers, including a previous pretty one, so repeated
    # calls don't print each record several times.
    logger.handlers = [
        h
        for h in logger.handlers
        if not (getattr(h, "_aw_default", False) or getattr(h, "_aw_pretty", False))
    ]
    # Add pretty handler
    h = logging.StreamHandler()
    h.setFormatter(StructuredPrettyFormatter(datefmt="%Y-%m-%d %H:%M:%S%z"))
    h._aw_pretty = True
    logger.addHandler(h)
    try:
        logger.setLevel(getattr(logging, level.upper(), logging.INFO))
    except Exception:
        logger.setLevel(logging.INFO)
    logger.propagate = False
|
@@ -0,0 +1,35 @@
|
|
1
|
+
"""Pluggable metrics provider abstraction."""
|
2
|
+
|
3
|
+
from __future__ import annotations
|
4
|
+
|
5
|
+
from typing import Any, Protocol
|
6
|
+
|
7
|
+
|
8
|
+
class MetricsProvider(Protocol):  # pragma: no cover - interface
    """Structural interface that any metrics backend must satisfy."""

    def inc(self, name: str, value: float = 1.0, **labels: Any) -> None:
        """Increase counter ``name`` by ``value``."""

    def observe(self, name: str, value: float, **labels: Any) -> None:
        """Record ``value`` as one sample of distribution ``name``."""


class InMemoryMetrics(MetricsProvider):
    """Process-local provider keeping counters and raw samples in dicts.

    Labels are accepted for interface compatibility but not stored: all
    updates for the same ``name`` are aggregated together.
    """

    def __init__(self) -> None:
        self.counters: dict[str, float] = {}
        self.histograms: dict[str, list[float]] = {}

    def inc(self, name: str, value: float = 1.0, **labels: Any) -> None:
        current = self.counters.get(name, 0.0)
        self.counters[name] = current + value

    def observe(self, name: str, value: float, **labels: Any) -> None:
        samples = self.histograms.get(name)
        if samples is None:
            samples = self.histograms[name] = []
        samples.append(value)


# Process-wide provider; replaced via set_metrics_provider().
_provider: MetricsProvider = InMemoryMetrics()


def set_metrics_provider(p: MetricsProvider) -> None:
    """Install ``p`` as the process-wide metrics provider."""
    global _provider
    _provider = p


def get_metrics_provider() -> MetricsProvider:
    """Return the currently installed metrics provider."""
    return _provider
|