auto-workflow 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,59 @@
1
+ """Middleware chaining for task execution (extensible)."""
2
+
3
+ from __future__ import annotations
4
+
5
+ from collections.abc import Awaitable, Callable
6
+ from typing import Any
7
+
8
+ from .events import emit
9
+
10
# Zero-argument callable producing an awaitable: the task's core execution,
# or the "next" step of the chain that is handed to a middleware.
TaskCallable = Callable[[], Awaitable[Any]]
# Middleware signature: (next, task_def, args, kwargs) -> awaitable result.
# A middleware continues the chain by awaiting ``next()``.
Middleware = Callable[[TaskCallable, Any, tuple, dict], Awaitable[Any]]

# Registered middlewares in registration order; index 0 is the outermost
# wrapper because the chain is entered at index 0.
_registry: list[Middleware] = []
14
+
15
+
16
def register(mw: Middleware) -> None:
    """Add ``mw`` to the end of the middleware chain (it runs inside earlier ones)."""
    _registry.extend((mw,))
18
+
19
+
20
def clear() -> None:  # pragma: no cover
    """Drop every registered middleware."""
    del _registry[:]
22
+
23
+
24
async def _call_chain(
    index: int, core: TaskCallable, task_def: Any, args: tuple, kwargs: dict
) -> Any:
    """Run middleware number ``index`` around the remainder of the chain.

    Each middleware is called as ``mw(next, task_def, args, kwargs)``; awaiting
    ``next()`` runs the following middleware and, at the end of the registry,
    ``core`` itself.

    Error handling:
    - If a middleware raises before it has started awaiting ``next()``, a
      ``"middleware_error"`` event is emitted and execution continues with the
      following middleware, as if the failing one were not registered.
    - Once ``next()`` has started running, any exception (from the task or
      from later middleware post-processing) propagates to the caller.
    """
    if index == len(_registry):
        # Every middleware has had its turn: execute the task itself.
        return await core()

    middleware = _registry[index]
    entered_next = False

    async def call_next():
        nonlocal entered_next
        entered_next = True
        return await _call_chain(index + 1, core, task_def, args, kwargs)

    try:
        return await middleware(lambda: call_next(), task_def, args, kwargs)
    except Exception as exc:  # noqa: BLE001
        if entered_next:
            # Failure came from the task or downstream: it is a task error.
            raise
        # Middleware failed on its own: report it and skip to the next one.
        emit(
            "middleware_error",
            {
                "task": getattr(task_def, "name", "unknown"),
                "middleware_index": index,
                "error": repr(exc),
            },
        )
        return await _call_chain(index + 1, core, task_def, args, kwargs)
53
+
54
+
55
def get_task_middleware_chain() -> Callable[[TaskCallable, Any, tuple, dict], Awaitable[Any]]:
    """Build a coroutine function that runs ``core`` through all registered middleware.

    The registry is consulted when the returned function runs, not when it is
    created, so middleware registered later still applies.
    """

    async def run_through_chain(
        core: TaskCallable, task_def: Any, args: tuple, kwargs: dict
    ) -> Any:
        return await _call_chain(0, core, task_def, args, kwargs)

    return run_through_chain
auto_workflow/py.typed ADDED
File without changes
@@ -0,0 +1,362 @@
1
+ """Simple topological scheduler executing TaskInvocations respecting dependencies."""
2
+
3
+ from __future__ import annotations
4
+
5
+ import asyncio
6
+ import time
7
+ from typing import Any
8
+
9
+ from .artifacts import get_store
10
+ from .build import TaskInvocation
11
+ from .cache import get_result_cache
12
+ from .dag import DAG
13
+ from .events import emit
14
+ from .exceptions import AggregateTaskError, TaskExecutionError
15
+ from .metrics_provider import get_metrics_provider
16
+ from .middleware import get_task_middleware_chain
17
+ from .tracing import get_tracer
18
+
19
+ # result cache implementation centralized in auto_workflow.cache
20
+
21
+
22
class FailurePolicy:
    """String constants selecting how ``execute_dag`` reacts to task failures."""

    # Stop at the first failure and raise its TaskExecutionError.
    FAIL_FAST: str = "fail_fast"
    # Record failures in the results and keep scheduling downstream tasks.
    CONTINUE: str = "continue"
    # Record failures, skip dependents, and raise AggregateTaskError at the end.
    AGGREGATE: str = "aggregate"
26
+
27
+
28
async def execute_dag(
    invocations: list[TaskInvocation],
    *,
    failure_policy: str = FailurePolicy.FAIL_FAST,
    max_concurrency: int | None = None,
    cancel_event: asyncio.Event | None = None,
    dynamic_roots: list[Any] | None = None,
) -> dict[str, Any]:
    """Execute ``invocations`` in dependency order and return their results.

    A DAG is built from each invocation's ``upstream`` names. Ready nodes are
    started as asyncio tasks, highest ``priority`` first with ties broken by
    node name, until nothing is ready or running. ``DynamicFanOut``
    placeholders (found in invocation arguments or listed in ``dynamic_roots``)
    are expanded once their source has a result. The generated child
    invocations join the graph, and consumers are rewired to depend on them.

    Args:
        invocations: Task invocations forming the initial graph.
        failure_policy: One of the ``FailurePolicy`` constants.
            ``FAIL_FAST`` cancels outstanding work and raises the first
            ``TaskExecutionError``. ``CONTINUE`` records failures in the
            results and still runs downstream tasks. ``AGGREGATE`` records
            failures, marks dependents of a failed node as failed without
            running them, and raises ``AggregateTaskError`` at the end.
        max_concurrency: Maximum number of tasks executing at once.
            ``None`` or 0 means unbounded.
        cancel_event: Checked before each scheduling round. When set, running
            tasks are cancelled and the results gathered so far are returned.
        dynamic_roots: Extra ``DynamicFanOut`` placeholders to track even if no
            invocation argument references them.

    Returns:
        Mapping of node name to its result. Under non-fail-fast policies a
        failed or skipped node maps to a ``TaskExecutionError``.

    Raises:
        TaskExecutionError: Under ``FAIL_FAST`` when a task fails, or when a
            dynamic fan-out source does not yield a list/tuple/set.
        AggregateTaskError: Under ``AGGREGATE`` when at least one task failed.
    """
    from contextlib import nullcontext, suppress

    # Kept function-local as in the original, presumably to avoid an import
    # cycle with the fan-out module (TODO confirm).
    from .fanout import DynamicFanOut

    dag = DAG()
    for inv in invocations:
        dag.add_node(inv.name)
        for up in inv.upstream:
            dag.add_edge(up, inv.name)
    order = dag.topological_sort()
    inv_map = {inv.name: inv for inv in invocations}
    results: dict[str, Any] = {}
    cache = get_result_cache()
    # Unfinished upstream names per node; a node is ready when its set is empty.
    remaining_deps = {name: set(inv_map[name].upstream) for name in order}
    ready = [n for n, deps in remaining_deps.items() if not deps]
    pending_tasks: dict[str, asyncio.Task[Any]] = {}

    # Global concurrency limit. ``None``/0 means unbounded. The limit used to
    # be sized to the *initial* node count, which silently throttled dynamic
    # fan-outs that add more children than there were original nodes.
    # (nullcontext is stateless, so one instance can be shared by all tasks.)
    concurrency_limit: Any = (
        asyncio.Semaphore(max_concurrency) if max_concurrency else nullcontext()
    )
    task_errors: list[TaskExecutionError] = []

    # Identical cached invocations (same cache key) may become ready together.
    # The first one (the "leader") runs; the others await its future.
    inflight: dict[str, asyncio.Future[Any]] = {}
    inflight_lock = asyncio.Lock()
    # Per dynamic fan-out concurrency control (placeholder id -> semaphore).
    placeholder_sems: dict[int, asyncio.Semaphore] = {}
    child_to_placeholder: dict[str, int] = {}

    def _inc_best_effort(metric: str) -> None:
        # Cache/dedup counters must never fail the task they describe.
        with suppress(Exception):
            get_metrics_provider().inc(metric)

    async def schedule(name: str) -> None:
        """Run one node and record its result (or failure) in ``results``."""
        inv = inv_map[name]
        # Replace upstream TaskInvocation references with their results.
        resolved_args = _hydrate(inv.args, results)
        resolved_kwargs = _hydrate(inv.kwargs, results)
        cache_ttl = getattr(inv.definition, "cache_ttl", None)
        cache_key = inv.definition.cache_key(*resolved_args, **resolved_kwargs)
        # Future registered by this node as de-dup leader.
        own_future: asyncio.Future[Any] | None = None
        # Leader's future this node joins as a follower.
        leader_future: asyncio.Future[Any] | None = None
        if cache_ttl is not None:
            # Atomic check-and-register for in-flight de-dup.
            async with inflight_lock:
                cached = cache.get(cache_key, cache_ttl)
                if cached is not None:
                    _inc_best_effort("cache_hits")
                    results[name] = cached
                    return
                leader_future = inflight.get(cache_key)
                if leader_future is None:
                    # Register before execution starts so followers can await it.
                    own_future = asyncio.get_running_loop().create_future()
                    inflight[cache_key] = own_future
        try:
            if leader_future is not None:
                # Awaiting inside ``try`` makes a leader failure go through
                # ``failure_policy`` like any other failure, instead of
                # escaping as a raw exception that aborts the whole DAG.
                _inc_best_effort("dedup_joins")
                results[name] = await leader_future
                return
            emit("task_started", {"task": inv.task_name, "node": name})
            start = time.time()
            async with concurrency_limit:
                tracer = get_tracer()
                async with tracer.span(f"task:{inv.task_name}", node=name):

                    async def core_run():
                        return await inv.definition.run(*resolved_args, **resolved_kwargs)

                    value = await get_task_middleware_chain()(
                        core_run, inv.definition, resolved_args, resolved_kwargs
                    )
            duration = time.time() - start
            # Artifact persistence if requested: store the value, keep the ref.
            if getattr(inv.definition, "persist", False):
                value = get_store().put(value)
            results[name] = value
            if own_future is not None:
                cache.set(cache_key, value)
                _inc_best_effort("cache_sets")
                # Resolve and unregister *our own* placeholder only.
                async with inflight_lock:
                    if not own_future.done():
                        own_future.set_result(value)
                    if inflight.get(cache_key) is own_future:
                        inflight.pop(cache_key)
            emit(
                "task_succeeded",
                {"task": inv.task_name, "node": name, "duration_ms": duration * 1000.0},
            )
            mp = get_metrics_provider()
            mp.inc("tasks_succeeded")
            mp.observe("task_duration_ms", duration * 1000.0)
        except Exception as e:  # noqa: BLE001
            te = TaskExecutionError(inv.task_name, e)
            if own_future is not None:
                async with inflight_lock:
                    if not own_future.done():
                        own_future.set_exception(e)
                        # Mark the exception as retrieved so asyncio does not
                        # log "Future exception was never retrieved" when no
                        # follower joined; followers still see it on await.
                        own_future.exception()
                    if inflight.get(cache_key) is own_future:
                        inflight.pop(cache_key)
            if failure_policy == FailurePolicy.FAIL_FAST:
                emit("task_failed", {"task": inv.task_name, "node": name, "error": repr(e)})
                # Record the failure so downstream inspection doesn't KeyError.
                results[name] = te
                get_metrics_provider().inc("tasks_failed")
                raise te from None
            task_errors.append(te)
            results[name] = te
            emit("task_failed", {"task": inv.task_name, "node": name, "error": repr(e)})
            get_metrics_provider().inc("tasks_failed")

    async def _cancel_pending() -> None:
        """Cancel every scheduled task and wait until all of them have settled."""
        for t in pending_tasks.values():
            t.cancel()
        if pending_tasks:
            await asyncio.gather(*pending_tasks.values(), return_exceptions=True)

    # Consumer invocation name -> DynamicFanOut placeholders in its arguments.
    consumer_placeholders: dict[str, list[DynamicFanOut]] = {}
    all_placeholders: list[DynamicFanOut] = []
    if dynamic_roots:
        all_placeholders.extend(p for p in dynamic_roots if isinstance(p, DynamicFanOut))
    for inv in invocations:
        for arg in list(inv.args) + list(inv.kwargs.values()):
            if isinstance(arg, DynamicFanOut):
                consumer_placeholders.setdefault(inv.name, []).append(arg)
                all_placeholders.append(arg)

    try:
        while ready or pending_tasks:
            if cancel_event and cancel_event.is_set():
                await _cancel_pending()
                break
            # A ready node is runnable only once every placeholder it consumes
            # is expanded and all of that placeholder's children have results.
            runnable = [
                node
                for node in ready
                if not any(
                    (not p._expanded) or any(child.name not in results for child in p)
                    for p in consumer_placeholders.get(node, [])
                )
            ]
            # Higher priority first; deterministic tie-break by node name.
            runnable.sort(
                key=lambda n: (-getattr(inv_map[n].definition, "priority", 0), n)
            )
            for node in runnable:
                ready.remove(node)
                if node in pending_tasks:
                    # Duplicate entry in ``ready``; already scheduled.
                    continue
                pid = child_to_placeholder.get(node)
                if pid is not None and pid in placeholder_sems:
                    # Child of a fan-out with its own concurrency limit.

                    async def schedule_with_limit(n=node, s=placeholder_sems[pid]):
                        async with s:
                            await schedule(n)

                    pending_tasks[node] = asyncio.create_task(schedule_with_limit())
                else:
                    pending_tasks[node] = asyncio.create_task(schedule(node))
            if not pending_tasks:
                break
            done, _ = await asyncio.wait(
                pending_tasks.values(), return_when=asyncio.FIRST_COMPLETED
            )
            finished_names = [n for n, t in pending_tasks.items() if t in done]
            for fname in finished_names:
                task = pending_tasks.pop(fname)
                exc = task.exception()
                if exc:
                    raise exc
                # Expand any placeholder whose source is now available
                # (supports nesting).
                for placeholder in list(all_placeholders):
                    if placeholder._expanded:
                        continue
                    src = placeholder._source
                    if isinstance(src, TaskInvocation):
                        ready_to_expand = src.name == fname
                    else:
                        # Nested fan-out: expand once every child of the source
                        # placeholder has a result.
                        ready_to_expand = (
                            isinstance(src, DynamicFanOut)
                            and src._expanded
                            and all(child.name in results for child in src)
                        )
                    if not ready_to_expand:
                        continue
                    if isinstance(src, TaskInvocation):
                        source_value = results[src.name]
                    else:
                        source_value = [results[c.name] for c in src]
                    if not isinstance(source_value, (list, tuple, set)):
                        raise TaskExecutionError(
                            getattr(src, "task_name", "dynamic"),
                            RuntimeError("Dynamic fan_out source must return an iterable"),
                        )
                    placeholder.expand(source_value)
                    for child_inv in placeholder:
                        # Track per-placeholder concurrency if specified.
                        if getattr(placeholder, "_max_conc", None):
                            pid = id(placeholder)
                            if pid not in placeholder_sems:
                                placeholder_sems[pid] = asyncio.Semaphore(placeholder._max_conc)
                            child_to_placeholder[child_inv.name] = pid
                        if child_inv.name not in dag.nodes:
                            dag.add_node(child_inv.name)
                        if isinstance(src, TaskInvocation):
                            dag.add_edge(src.name, child_inv.name)
                        else:
                            for d in {c.name for c in src}:
                                dag.add_edge(d, child_inv.name)
                        # Its dependencies already completed (that is what made
                        # the placeholder expandable), so the child is ready now.
                        remaining_deps[child_inv.name] = set()
                        ready.append(child_inv.name)
                        if child_inv.name not in inv_map:
                            inv_map[child_inv.name] = child_inv
                    # Rewrite consumers: replace the placeholder with its children.
                    for consumer in invocations:
                        if consumer is src:
                            continue
                        replaced = False

                        def _walk(o, target=placeholder):  # bind loop var
                            nonlocal replaced
                            if o is target:
                                replaced = True
                                return list(target)
                            # Other placeholders stay atomic until their own
                            # expansion pass.
                            if isinstance(o, DynamicFanOut):
                                return o
                            if isinstance(o, list):
                                return [_walk(x, target) for x in o]
                            if isinstance(o, tuple):
                                return tuple(_walk(x, target) for x in o)
                            if isinstance(o, dict):
                                return {k: _walk(v, target) for k, v in o.items()}
                            return o

                        consumer.args = (
                            _walk(list(consumer.args)) if consumer.args else consumer.args
                        )
                        consumer.kwargs = (
                            _walk(consumer.kwargs) if consumer.kwargs else consumer.kwargs
                        )
                        if replaced:
                            for child_inv in placeholder:
                                dag.add_edge(child_inv.name, consumer.name)
                                remaining_deps.setdefault(consumer.name, set()).add(
                                    child_inv.name
                                )
                    # Placeholders remain for nested detection; do not delete.
                # Update downstream readiness. set() guards against duplicate
                # edges causing repeated scheduling attempts.
                for child in set(dag.nodes[fname].downstream):
                    remaining_deps[child].discard(fname)
                    upstream_failed = any(
                        isinstance(results.get(up), TaskExecutionError)
                        for up in dag.nodes[child].upstream
                        if up in results
                    )
                    if upstream_failed and failure_policy != FailurePolicy.CONTINUE:
                        results[child] = TaskExecutionError(child, RuntimeError("Upstream failed"))
                        continue
                    if (
                        not remaining_deps[child]
                        and child not in ready
                        and child not in pending_tasks
                        and child not in results
                    ):
                        ready.append(child)
    except BaseException:
        # A FAIL_FAST failure, an invalid fan-out source, or cancellation of
        # execute_dag itself: don't leave already-scheduled tasks running
        # detached in the background.
        await _cancel_pending()
        raise
    if failure_policy == FailurePolicy.AGGREGATE and task_errors:
        raise AggregateTaskError(task_errors)
    return results
347
+
348
+
349
+ def _hydrate(struct: Any, results: dict[str, Any]) -> Any:
350
+ from .build import TaskInvocation
351
+
352
+ if isinstance(struct, TaskInvocation):
353
+ return results[struct.name]
354
+ if isinstance(struct, list):
355
+ return [_hydrate(s, results) for s in struct]
356
+ if isinstance(struct, tuple):
357
+ return tuple(_hydrate(s, results) for s in struct)
358
+ if isinstance(struct, set):
359
+ return {_hydrate(s, results) for s in struct}
360
+ if isinstance(struct, dict):
361
+ return {k: _hydrate(v, results) for k, v in struct.items()}
362
+ return struct
@@ -0,0 +1,45 @@
1
+ """Secrets provider abstraction."""
2
+
3
+ from __future__ import annotations
4
+
5
+ import os
6
+ from typing import Protocol
7
+
8
+
9
class SecretsProvider(Protocol):  # pragma: no cover - interface
    """Structural interface for objects that look up secrets by key."""

    def get(self, key: str) -> str | None:
        """Return the secret stored under ``key``, or ``None`` when absent."""
11
+
12
+
13
class EnvSecrets(SecretsProvider):
    """Secrets provider backed by the process environment variables."""

    def get(self, key: str) -> str | None:
        return os.getenv(key)
16
+
17
+
18
class StaticMappingSecrets(SecretsProvider):
    """Secrets provider serving values from an in-memory mapping."""

    def __init__(self, data: dict[str, str]):
        # Kept by reference: later changes to ``data`` are visible to ``get``.
        self.data = data

    def get(self, key: str) -> str | None:
        mapping = self.data
        return mapping.get(key)
24
+
25
+
26
class DummyVaultSecrets(SecretsProvider):  # placeholder for future HashiCorp Vault integration
    """Stand-in for a Vault-backed provider.

    For now it reads the environment variable named ``prefix + key``.
    """

    def __init__(self, prefix: str = ""):
        # A real implementation would keep a client/session here.
        self.prefix = prefix

    def get(self, key: str) -> str | None:  # pragma: no cover - placeholder
        env_name = self.prefix + key
        return os.getenv(env_name)
34
+
35
+
36
# Module-level provider consulted by ``secret()``; defaults to environment
# variables and can be swapped with ``set_secrets_provider()``.
_provider: SecretsProvider = EnvSecrets()
37
+
38
+
39
def set_secrets_provider(p: SecretsProvider) -> None:
    """Install ``p`` as the provider that ``secret()`` delegates to."""
    global _provider
    _provider = p
42
+
43
+
44
def secret(key: str) -> str | None:
    """Look up ``key`` through the currently installed secrets provider."""
    provider = _provider
    return provider.get(key)
auto_workflow/task.py ADDED
@@ -0,0 +1,158 @@
1
+ """Task abstraction and decorator."""
2
+
3
+ from __future__ import annotations
4
+
5
+ import asyncio
6
+ import builtins
7
+ import functools
8
+ import inspect
9
+ from collections.abc import Callable
10
+ from dataclasses import dataclass, field
11
+ from typing import Any
12
+
13
+ from .build import current_build_context
14
+ from .events import emit
15
+ from .exceptions import RetryExhaustedError, TimeoutError
16
+ from .tracing import get_tracer
17
+ from .types import TaskFn
18
+ from .utils import default_cache_key, maybe_await
19
+
20
+
21
@dataclass(slots=True)
class TaskDefinition:
    """A task function together with its execution policy.

    Calling an instance inside a flow build registers an invocation with the
    active build context. Calling it outside a build executes it immediately.
    """

    name: str
    fn: TaskFn
    original_fn: TaskFn | None = None
    # Extra attempts after the first failed one.
    retries: int = 0
    # Base retry delay in seconds, doubled for each subsequent retry.
    retry_backoff: float = 0.0
    # Upper bound of a uniform random delay added to every retry sleep.
    retry_jitter: float = 0.0
    # Per-attempt timeout in seconds; a falsy value (None or 0) disables it.
    timeout: float | None = None
    cache_ttl: int | None = None
    cache_key_fn: Callable[..., str] = default_cache_key
    tags: set[str] = field(default_factory=set)
    run_in: str = "async"  # one of: async, thread, process
    persist: bool = False  # store large result via artifact store (future)
    priority: int = 0  # higher runs earlier when ready

    async def run(self, *args: Any, **kwargs: Any) -> Any:
        """Execute ``fn`` with the configured executor, timeout and retries.

        ``run_in`` selects how each attempt runs:
        - ``"async"``: call ``fn`` and pass the result through ``maybe_await``.
        - ``"thread"``: run it in a worker thread via ``asyncio.to_thread``.
        - ``"process"``: cloudpickle ``(fn, args, kwargs)`` and execute it in
          the process pool.

        A failed attempt is retried up to ``retries`` times. Before retry
        number k it sleeps ``retry_backoff * 2**(k - 1)`` seconds, plus up to
        ``retry_jitter`` random seconds.

        Raises:
            TimeoutError: the package's own timeout error, if the final
                attempt timed out.
            RetryExhaustedError: if the final attempt failed any other way
                (also with ``retries == 0``). An unknown ``run_in`` value
                surfaces this way, wrapping a ``RuntimeError``.
        """
        attempt = 0
        while True:

            async def _invoke():
                tracer = get_tracer()
                async with tracer.span(f"task:{self.name}"):
                    if self.run_in == "async":
                        return await maybe_await(self.fn(*args, **kwargs))
                    if self.run_in == "thread":
                        return await asyncio.to_thread(self.fn, *args, **kwargs)
                    if self.run_in == "process":
                        from .execution import get_process_pool

                        loop = asyncio.get_running_loop()
                        pool = get_process_pool()
                        import cloudpickle

                        from .execution import run_pickled

                        fn_bytes = cloudpickle.dumps((self.fn, args, kwargs))
                        return await loop.run_in_executor(
                            pool, functools.partial(run_pickled, fn_bytes)
                        )
                    raise RuntimeError(f"Unknown run_in mode: {self.run_in}")

            try:
                if self.timeout:
                    return await asyncio.wait_for(_invoke(), timeout=self.timeout)
                return await _invoke()
            except (builtins.TimeoutError, asyncio.TimeoutError) as te:
                # On Python 3.10 ``asyncio.wait_for`` raises asyncio.TimeoutError,
                # which only became an alias of the builtin in 3.11. Catch both
                # so timeouts are reported as TimeoutError, not RetryExhaustedError.
                err: Exception = TimeoutError(self.name, te)
            except Exception as e:  # noqa: BLE001
                err = e
            # error path
            if attempt < self.retries:
                attempt += 1
                emit("task_retry", {"task": self.name, "attempt": attempt, "max": self.retries})
                sleep_dur = self.retry_backoff * (2 ** (attempt - 1))
                if self.retry_jitter:
                    import random

                    sleep_dur += random.uniform(0, self.retry_jitter)
                await asyncio.sleep(sleep_dur)
                continue
            if isinstance(err, TimeoutError):
                raise err
            raise RetryExhaustedError(self.name, err) from err

    def cache_key(self, *args: Any, **kwargs: Any) -> str:
        """Return the result-cache key: ``cache_key_fn(fn, args, kwargs)``."""
        return self.cache_key_fn(self.fn, args, kwargs)

    def __call__(self, *args: Any, **kwargs: Any) -> Any:  # building vs immediate execution
        """Register an invocation during a flow build, otherwise run immediately.

        Outside a build this uses ``asyncio.run``, so it cannot be called from
        code already running inside an event loop. When ``persist`` is set,
        the artifact-store reference is returned instead of the value.
        """
        ctx = current_build_context()
        if ctx is None:
            # immediate execution outside a flow build (synchronous helper)
            tracer = get_tracer()

            async def _run():
                async with tracer.span(f"task:{self.name}"):
                    val = await self.run(*args, **kwargs)
                    if self.persist:
                        from .artifacts import get_store

                        store = get_store()
                        ref = store.put(val)
                        return ref
                    return val

            return asyncio.run(_run())
        return ctx.register(self.name, self.fn, args, kwargs, self)
108
+
109
+
110
def task(
    _fn: TaskFn | None = None,
    *,
    name: str | None = None,
    retries: int = 0,
    retry_backoff: float = 0.0,
    retry_jitter: float = 0.0,
    timeout: float | None = None,
    cache_ttl: int | None = None,
    cache_key_fn: Callable[..., str] = default_cache_key,
    tags: set[str] | None = None,
    run_in: str | None = None,
    persist: bool = False,
    priority: int = 0,
) -> TaskDefinition | Callable[[TaskFn], TaskDefinition]:
    """Decorator to define a task.

    Usable bare (``@task``), in which case it returns the ``TaskDefinition``
    directly, or with options (``@task(retries=3)``), in which case it returns
    a decorator producing one.

    Auto executor selection:
    - If run_in explicitly provided -> honored.
    - If function is an async def -> defaults to "async".
    - Else (synchronous callable) -> defaults to "thread" to avoid blocking the event loop.

    ``tags`` may be any iterable of strings. It is copied into a new set, so
    later changes to the caller's collection do not leak into the definition.
    """

    def wrap(fn: TaskFn) -> TaskDefinition:
        inferred_run_in = run_in
        if inferred_run_in is None:
            # Safe default: keep blocking sync functions off the event loop.
            inferred_run_in = "async" if inspect.iscoroutinefunction(fn) else "thread"
        return TaskDefinition(
            name=name or fn.__name__,
            fn=fn,
            original_fn=fn,
            retries=retries,
            retry_backoff=retry_backoff,
            retry_jitter=retry_jitter,
            timeout=timeout,
            cache_ttl=cache_ttl,
            cache_key_fn=cache_key_fn,
            tags=set(tags) if tags else set(),
            run_in=inferred_run_in,
            persist=persist,
            priority=priority,
        )

    if _fn is not None:
        return wrap(_fn)
    return wrap
@@ -0,0 +1,29 @@
1
+ """Tracing scaffold (OpenTelemetry friendly)."""
2
+
3
+ from __future__ import annotations
4
+
5
+ import time
6
+ from contextlib import asynccontextmanager
7
+ from typing import Any
8
+
9
+
10
class DummyTracer:
    """No-op tracer whose spans only yield a dict describing themselves."""

    @asynccontextmanager
    async def span(self, name: str, **attrs: Any):  # pragma: no cover simple scaffold
        """Yield ``{"start": <time.time()>, "name": name, **attrs}`` for the body.

        The elapsed time is computed on exit but not recorded anywhere. A real
        tracer can be installed with ``set_tracer``.
        """
        began = time.time()
        info = {"start": began, "name": name}
        info.update(attrs)
        try:
            yield info
        finally:
            _elapsed = time.time() - began  # noqa: F841 - placeholder for export
18
+
19
+
20
# Module-level tracer returned by ``get_tracer()``; replace it via ``set_tracer()``.
_tracer: DummyTracer = DummyTracer()
21
+
22
+
23
def get_tracer() -> DummyTracer:
    """Return the currently installed tracer."""
    current = _tracer
    return current
25
+
26
+
27
def set_tracer(t: DummyTracer) -> None:
    """Install ``t`` as the tracer returned by ``get_tracer()``."""
    global _tracer
    _tracer = t
auto_workflow/types.py ADDED
@@ -0,0 +1,27 @@
1
+ """Common type aliases and Protocols."""
2
+
3
+ from __future__ import annotations
4
+
5
+ from collections.abc import Awaitable, Callable
6
+ from typing import Protocol, TypeVar, runtime_checkable
7
+
8
# Generic type variables; ``R`` is a task function's result type in ``TaskFn``.
T = TypeVar("T")
R = TypeVar("R")

# A task function: a callable returning ``R`` directly, or returning an
# awaitable that produces ``R`` (e.g. an ``async def`` function).
TaskFn = Callable[..., R] | Callable[..., Awaitable[R]]
12
+
13
+
14
@runtime_checkable
class SupportsHash(Protocol):
    """Runtime-checkable structural protocol for objects exposing ``__hash__``."""

    def __hash__(self) -> int:  # pragma: no cover
        """Return the object's hash value."""
17
+
18
+
19
# Alias documenting that result-cache keys are plain strings.
CacheKey = str
20
+
21
+
22
class CancelledSentinel:
    """Type of the ``CANCELLED`` marker (presumably flags cancelled work — confirm at call sites)."""

    def __repr__(self) -> str:  # pragma: no cover
        text = "<CANCELLED>"
        return text
25
+
26
+
27
# Shared sentinel instance; compare with ``is``. Its repr is "<CANCELLED>".
CANCELLED = CancelledSentinel()