dagrun 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
dagrun/__init__.py ADDED
@@ -0,0 +1,6 @@
1
+ from dagrun import model
2
+ from dagrun.control import Defer
3
+ from dagrun.dag import Dag, DagValidationError
4
+ from dagrun.dag_runner import DagRunner
5
+
6
# Names re-exported as the package's public API.
__all__ = ["Dag", "DagRunner", "DagValidationError", "Defer", "model"]
dagrun/chain.py ADDED
@@ -0,0 +1,71 @@
1
+ from collections.abc import Callable, Mapping
2
+ from dataclasses import dataclass
3
+ from typing import Any, Protocol
4
+
5
+
6
def member(fn: Callable[..., Any]) -> Callable[..., Any]:
    """Flag *fn* as a member function of a ``@dag.group`` class.

    The function is tagged with a truthy ``__dagrun_member__`` attribute
    (which group registration looks for) and handed back unchanged, so this
    works as a plain decorator.
    """
    setattr(fn, "__dagrun_member__", True)
    return fn
9
+
10
+
11
class Combine(Protocol):
    """Structural interface for merging several candidate values of a column.

    Implementations receive the column name and the list of candidate values
    and return the single value to keep.
    """

    def reduce(self, column: str, values: list[Any]) -> Any:
        """Collapse *values* for *column* into one value."""
        ...
13
+
14
+
15
@dataclass(frozen=True)
class LastWins:
    """Combine rule that keeps the final value in the candidate list."""

    def reduce(self, column: str, values: list[Any]) -> Any:
        # `column` is part of the Combine signature but irrelevant here.
        last_index = -1
        return values[last_index]
19
+
20
+
21
@dataclass(frozen=True)
class FirstWins:
    """Combine rule that keeps the first value in the candidate list."""

    def reduce(self, column: str, values: list[Any]) -> Any:
        # `column` is part of the Combine signature but irrelevant here.
        first_index = 0
        return values[first_index]
25
+
26
+
27
@dataclass(frozen=True)
class Error:
    """Combine rule that insists every candidate value agrees.

    Values are compared through the module's ``_hash`` key helper, so
    unhashable values are compared by ``repr``. Raises
    :class:`CombineConflict` when more than one distinct key is seen;
    otherwise returns the first value.
    """

    def reduce(self, column: str, values: list[Any]) -> Any:
        unique_keys = {_hash(item) for item in values}
        if len(unique_keys) <= 1:
            return values[0]
        raise CombineConflict(column, values)
34
+
35
+
36
class CombineConflict(RuntimeError):
    """Raised when a column receives disagreeing values under the Error rule.

    Attributes:
        column: name of the conflicting column.
        values: the full list of candidate values that disagreed.
    """

    def __init__(self, column: str, values: list[Any]) -> None:
        message = f"Conflicting values for column {column!r}: {values!r}"
        super().__init__(message)
        self.column, self.values = column, values
41
+
42
+
43
+ def _hash(value: Any) -> Any:
44
+ try:
45
+ hash(value)
46
+ except TypeError:
47
+ return repr(value)
48
+ return value
49
+
50
+
51
@dataclass(frozen=True)
class PerColumn:
    """Per-column combine override map. '*' means default for unlisted columns.

    ``rules`` is stored as a tuple of ``(column, combine)`` pairs so the
    instance stays hashable. A Mapping passed to the constructor is converted
    to that form. Columns with neither their own rule nor a ``'*'`` entry fall
    back to keeping the last value. With duplicate column keys, the later
    pair wins (``dict`` semantics).
    """

    rules: tuple[tuple[str, Combine], ...]

    # The parameter annotation is a real generic rather than a quoted string:
    # without postponed annotations, `GenericAlias | "str"` raises TypeError
    # when the def is evaluated on Python 3.10-3.13.
    def __init__(
        self,
        rules: Mapping[str, Combine] | tuple[tuple[str, Combine], ...],
    ) -> None:
        if isinstance(rules, Mapping):
            items = tuple(rules.items())
        else:
            items = tuple(rules)
        # Frozen dataclass: bypass the generated __setattr__ guard.
        object.__setattr__(self, "rules", items)
        # Build the lookup table once instead of on every reduce() call. It is
        # not a dataclass field, so eq/hash/repr still depend only on `rules`.
        object.__setattr__(self, "_lookup", dict(items))

    def reduce(self, column: str, values: list[Any]) -> Any:
        lookup = self._lookup
        rule = lookup.get(column, lookup.get("*"))
        if rule is None:
            return values[-1]
        return rule.reduce(column, values)
70
+
71
+
dagrun/control.py ADDED
@@ -0,0 +1,9 @@
1
class Defer(Exception):
    """Raised by a node to postpone work on the current input.

    A Defer means "not now" rather than "done". The runner throws away any
    output the node produced for that input and records no invocation, so
    the next run evaluates the input again from the beginning. Intended for
    decisions driven by runtime configuration (for example a token budget),
    not by an intrinsic, persistable property of the input.
    """
dagrun/dag.py ADDED
@@ -0,0 +1,396 @@
1
+ from __future__ import annotations
2
+
3
+ import inspect
4
+ import types
5
+ import typing
6
+ from collections.abc import (
7
+ AsyncIterable,
8
+ AsyncIterator,
9
+ Callable,
10
+ Generator,
11
+ Iterable,
12
+ Iterator,
13
+ Mapping,
14
+ )
15
+ from dataclasses import dataclass, field
16
+ from typing import Any, Union, get_args, get_origin
17
+
18
+ from dagrun.chain import Combine, LastWins, PerColumn
19
+ from dagrun.model import (
20
+ Entity,
21
+ SYSTEM_COLUMNS,
22
+ _entity_specs,
23
+ _EntityView,
24
+ _find_pk,
25
+ _find_unique_fk_to,
26
+ _required_columns,
27
+ )
28
+ from dagrun.plan import CompiledChain, CompiledInput, CompiledMember, Plan
29
+
30
+
31
class DagValidationError(Exception):
    """Raised when registering a chain or compiling a Dag fails validation."""
33
+
34
+
35
+ @dataclass
36
+ class _Chain:
37
+ members: list[Callable[..., Any]]
38
+ combine: Combine
39
+ labels: set[str] = field(default_factory=set)
40
+ cost: Mapping[str, int] = field(default_factory=dict)
41
+
42
+
43
+ _ITERABLE_ORIGINS: frozenset[Any] = frozenset(
44
+ {Iterable, Iterator, Generator, AsyncIterable, AsyncIterator, list, tuple, set, frozenset}
45
+ )
46
+
47
+
48
+ def _unwrap_input_param(
49
+ ann: Any,
50
+ ) -> tuple[type | None, tuple[str, ...], bool]:
51
+ """Like _unwrap_entity, but additionally reports whether the outermost
52
+ wrapper was a collection (list/Iterable/etc.). A collection annotation on
53
+ an input parameter signals an aggregate chain: dagrun groups child rows
54
+ by their FK to the produced parent entity and calls the fn once per
55
+ parent with all matching children."""
56
+ if ann is inspect.Parameter.empty or ann is inspect.Signature.empty:
57
+ return None, (), False
58
+ origin = get_origin(ann)
59
+ if origin in _ITERABLE_ORIGINS:
60
+ args = get_args(ann)
61
+ if not args:
62
+ return None, (), False
63
+ entity, cols = _unwrap_entity(args[0])
64
+ if entity is None:
65
+ return None, (), False
66
+ return entity, cols, True
67
+ entity, cols = _unwrap_entity(ann)
68
+ return entity, cols, False
69
+
70
+
71
+ def _unwrap_entity(ann: Any) -> tuple[type | None, tuple[str, ...]]:
72
+ """Peel Iterable/Union/Annotated wrappers and return (entity, columns).
73
+
74
+ Returns (None, ()) for annotations that don't resolve to an Entity subclass.
75
+ For a bare Entity, returns all required column names.
76
+ """
77
+ if ann is inspect.Parameter.empty or ann is inspect.Signature.empty:
78
+ return None, ()
79
+
80
+ metadata = getattr(ann, "__metadata__", None)
81
+ if metadata is not None:
82
+ for m in metadata:
83
+ if isinstance(m, _EntityView):
84
+ return m.entity, m.columns
85
+ return _unwrap_entity(ann.__origin__)
86
+
87
+ origin = get_origin(ann)
88
+ if origin is Union or origin is types.UnionType:
89
+ non_none = [a for a in get_args(ann) if a is not type(None)]
90
+ if len(non_none) == 1:
91
+ return _unwrap_entity(non_none[0])
92
+ return None, ()
93
+
94
+ if origin in _ITERABLE_ORIGINS:
95
+ args = get_args(ann)
96
+ if args:
97
+ return _unwrap_entity(args[0])
98
+ return None, ()
99
+
100
+ if isinstance(ann, type) and issubclass(ann, Entity):
101
+ return ann, _required_columns(ann)
102
+
103
+ return None, ()
104
+
105
+
106
+ def _detect_cycle(
107
+ edges: dict[CompiledChain, list[CompiledChain]],
108
+ ) -> list[CompiledChain] | None:
109
+ color: dict[CompiledChain, int] = {}
110
+ stack: list[CompiledChain] = []
111
+
112
+ def dfs(node: CompiledChain) -> list[CompiledChain] | None:
113
+ color[node] = 1
114
+ stack.append(node)
115
+ for nxt in edges.get(node, ()):
116
+ c = color.get(nxt, 0)
117
+ if c == 1:
118
+ i = stack.index(nxt)
119
+ return stack[i:] + [nxt]
120
+ if c == 0:
121
+ cyc = dfs(nxt)
122
+ if cyc is not None:
123
+ return cyc
124
+ stack.pop()
125
+ color[node] = 2
126
+ return None
127
+
128
+ for n in edges:
129
+ if color.get(n, 0) == 0:
130
+ cyc = dfs(n)
131
+ if cyc is not None:
132
+ return cyc
133
+ return None
134
+
135
+
136
class Dag:
    """Registry of chains that compiles into an executable Plan.

    Work is registered with ``@dag.fn`` (a single function) or ``@dag.group``
    (a class whose ``@member`` functions form one chain). :meth:`compile`
    then validates column ownership and dependencies and builds the Plan.
    """

    def __init__(self) -> None:
        # Registered chains, in registration order; compile() parses them in this order.
        self._chains: list[_Chain] = []

    def fn(
        self,
        target: Any = None,
        *,
        labels: set[str] | None = None,
        cost: Mapping[str, int] | None = None,
    ) -> Any:
        """Register a single function as a chain.

        Usable bare (``@dag.fn``) or with options (``@dag.fn(labels=...)``).
        Solo chains always combine with LastWins.

        Returns:
            The function unchanged, or the decorator itself when called
            without a target.

        Raises:
            DagValidationError: if the target is a class.
        """
        return self._register(
            target, labels, combine=LastWins(), is_group=False, cost=cost
        )

    def group(
        self,
        target: Any = None,
        *,
        labels: set[str] | None = None,
        combine: Combine | Mapping[str, Combine] | None = None,
        cost: Mapping[str, int] | None = None,
    ) -> Any:
        """Register a class of ``@member`` functions as one chain.

        Args:
            combine: how values that several members produce for the same
                column are merged. Either a Combine, or a mapping of column
                name to Combine (wrapped in PerColumn; ``'*'`` is the default
                entry). Defaults to LastWins.

        Raises:
            DagValidationError: if the target is not a class or has no members.
        """
        resolved = combine
        if isinstance(resolved, Mapping):
            resolved = PerColumn(rules=dict(resolved))
        # NOTE(review): `or` also replaces a non-None combine object that is
        # falsy (e.g. one defining __len__/__bool__), not just None.
        return self._register(
            target, labels, combine=resolved or LastWins(), is_group=True, cost=cost
        )

    def compile(self) -> Plan:
        """Validate the registered chains and build an execution Plan.

        Returns:
            A Plan with the compiled chains, the (entity, column) -> owning
            chain map, known column kinds, and dependency edges (producer
            chain -> chains that read its columns).

        Raises:
            DagValidationError: if members of a chain disagree on the produced
                entity or on input entities; if an aggregate (collection-input)
                chain declares no produced entity or its child entity lacks a
                unique FK to the produced entity; if a column is produced by
                more than one chain; if a required column has no producer; or
                if the dependency edges form a cycle.
        """
        parsed: list[_ParsedChain] = [_parse_chain(raw) for raw in self._chains]

        # Pass 1: turn each parsed chain into a CompiledChain and validate
        # aggregate inputs.
        compiled_chains: list[CompiledChain] = []
        parsed_to_compiled: dict[int, CompiledChain] = {}
        for p in parsed:
            labels = set(p.raw.labels)
            if p.produces is not None:
                # The produced entity's lowercased name doubles as a label.
                labels.add(p.produces.__name__.lower())
            for e, _cols, is_coll in p.inputs:
                if not is_coll:
                    continue
                if p.produces is None:
                    raise DagValidationError(
                        f"Aggregate chain reading {e.__name__} must produce "
                        f"an entity (none declared)."
                    )
                # Children are grouped by their FK to the parent, so exactly
                # one such FK must exist.
                try:
                    _find_unique_fk_to(e, p.produces)
                except LookupError as exc:
                    raise DagValidationError(str(exc)) from exc

            compiled = CompiledChain(
                # Chains that produce no entity are recorded with `object`.
                produces=p.produces if p.produces is not None else object,
                inputs=tuple(
                    CompiledInput(entity=e, columns=cols, is_collection=is_coll)
                    for e, cols, is_coll in p.inputs
                ),
                members=tuple(p.members),
                combine=p.raw.combine,
                labels=frozenset(labels),
                cost=dict(p.raw.cost),
            )
            compiled_chains.append(compiled)
            # Keyed by id(): the parsed objects stay alive in `parsed` for the
            # whole method.
            parsed_to_compiled[id(p)] = compiled

        # Pass 2: give every produced (entity, column) exactly one owner chain.
        column_owners: dict[tuple[type, str], CompiledChain] = {}
        column_kinds: dict[tuple[type, str], type] = {}

        for p in parsed:
            if p.produces is None:
                continue
            specs = _entity_specs(p.produces)
            compiled = parsed_to_compiled[id(p)]
            for col in p.produces_cols:
                if col in SYSTEM_COLUMNS:
                    continue
                key = (p.produces, col)
                if key in column_owners:
                    raise DagValidationError(
                        f"Column {p.produces.__name__}.{col} produced by multiple chains"
                    )
                column_owners[key] = compiled
                if col in specs:
                    column_kinds[key] = specs[col].kind

        # Pass 3: add an edge owner -> reader for every column a member needs.
        edges: dict[CompiledChain, list[CompiledChain]] = {c: [] for c in compiled_chains}
        for p in parsed:
            compiled = parsed_to_compiled[id(p)]
            for m in p.members:
                for entity, cols in m.needs_entities:
                    # With no specific columns, or only system columns, the
                    # dependency is on the entity's primary key.
                    needed_raw = list(cols) if cols else [_find_pk(entity)]
                    needed = [c for c in needed_raw if c not in SYSTEM_COLUMNS] or [
                        _find_pk(entity)
                    ]
                    for col in needed:
                        owner = column_owners.get((entity, col))
                        if owner is None:
                            raise DagValidationError(
                                f"No chain produces {entity.__name__}.{col} "
                                f"(required by {_chain_label(p)})"
                            )
                        # Reading a column the chain itself produces is not a dependency.
                        if owner is compiled:
                            continue
                        if compiled not in edges[owner]:
                            edges[owner].append(compiled)

        cycle = _detect_cycle(edges)
        if cycle is not None:
            # NOTE(review): chains are named by produced entity only, so two
            # chains enriching the same entity print identically here.
            names = " -> ".join(c.produces.__name__ for c in cycle)
            raise DagValidationError(f"Cycle detected: {names}")

        return Plan(
            chains=tuple(compiled_chains),
            column_owners=column_owners,
            column_kinds=column_kinds,
            edges={k: tuple(v) for k, v in edges.items()},
        )

    def _register(
        self,
        target: Any,
        labels: set[str] | None,
        combine: Combine,
        is_group: bool,
        cost: Mapping[str, int] | None = None,
    ) -> Any:
        """Shared implementation of fn()/group().

        Returns the registering decorator when *target* is None (decorator
        used with arguments); otherwise registers *target* and returns it.
        """
        def register(t: Any) -> Any:
            if is_group:
                if not isinstance(t, type):
                    raise DagValidationError(
                        "@dag.group expects a class with @member functions; "
                        "use @dag.fn for a solo function"
                    )
                # Only attributes defined directly on the class are scanned
                # (vars(t)); inherited members are not collected.
                members = [
                    v
                    for v in vars(t).values()
                    if callable(v) and getattr(v, "__dagrun_member__", False)
                ]
                if not members:
                    raise DagValidationError(
                        f"@dag.group {t.__name__} has no @member functions"
                    )
            else:
                if isinstance(t, type):
                    raise DagValidationError(
                        "@dag.fn expects a function; use @dag.group for a class with members"
                    )
                members = [t]
            # The decorated object's name is always a label.
            auto_labels = {t.__name__}
            # User-supplied costs override the default concurrency of 1.
            merged_cost: dict[str, int] = {"concurrency": 1}
            if cost is not None:
                merged_cost.update(cost)
            self._chains.append(
                _Chain(
                    members=members,
                    combine=combine,
                    labels=auto_labels | (labels or set()),
                    cost=merged_cost,
                )
            )
            return t

        if target is None:
            return register
        return register(target)
302
+
303
+
304
+ @dataclass
305
+ class _ParsedChain:
306
+ raw: _Chain
307
+ produces: type | None
308
+ produces_cols: frozenset[str]
309
+ inputs: tuple[tuple[type, frozenset[str], bool], ...]
310
+ members: list[CompiledMember]
311
+
312
+
313
+ def _chain_label(p: _ParsedChain) -> str:
314
+ fn = p.raw.members[0] if p.raw.members else None
315
+ return getattr(fn, "__qualname__", repr(p.raw))
316
+
317
+
318
def _parse_chain(raw: _Chain) -> _ParsedChain:
    """Infer a chain's inputs, outputs and members from type annotations.

    For each member function: parameters whose annotation resolves to an
    Entity become entity inputs (with is_collection marking aggregate
    inputs); other parameters annotated with a plain class become provider
    dependencies. The return annotation gives the produced entity and its
    columns.

    Raises:
        DagValidationError: if members return different entities, or if they
            read different sets of input entities.
    """
    members: list[CompiledMember] = []
    produces: type | None = None
    produces_cols_union: set[str] = set()
    chain_inputs: tuple[tuple[type, frozenset[str], bool], ...] | None = None
    # True when some member reads the entity it produces, or reads a
    # collection; in that case the chain does not claim the primary key.
    any_member_has_same_entity_input = False

    for fn in raw.members:
        # NOTE(review): any failure to resolve hints is swallowed. The raw
        # `param.annotation` values may then be unevaluated strings (under
        # postponed annotations), which _unwrap_* resolve to no entity, so such
        # parameters are silently ignored.
        try:
            hints = typing.get_type_hints(fn, include_extras=True)
        except Exception:
            hints = {}
        sig = inspect.signature(fn)

        needs_entities: list[tuple[type, frozenset[str], bool]] = []
        needs_providers: list[type] = []
        for pname, param in sig.parameters.items():
            ann = hints.get(pname, param.annotation)
            entity, cols, is_collection = _unwrap_input_param(ann)
            if entity is not None:
                needs_entities.append((entity, frozenset(cols), is_collection))
            elif ann is not inspect.Parameter.empty and isinstance(ann, type):
                needs_providers.append(ann)

        ret_ann = hints.get("return", sig.return_annotation)
        ret_entity, ret_cols = _unwrap_entity(ret_ann)

        if ret_entity is not None:
            if produces is None:
                produces = ret_entity
            elif produces is not ret_entity:
                raise DagValidationError(
                    f"Members of chain disagree on produced entity: "
                    f"{produces.__name__} vs {ret_entity.__name__}"
                )

        # Enrichment: the member reads and writes the same entity.
        if ret_entity is not None and any(
            e is ret_entity for e, _, _ in needs_entities
        ):
            any_member_has_same_entity_input = True

        # Aggregation: the member reads a collection of (child) entities.
        if ret_entity is not None and any(
            is_coll for _, _, is_coll in needs_entities
        ):
            any_member_has_same_entity_input = True

        produced_cols = frozenset(ret_cols)
        produces_cols_union |= set(ret_cols)

        # Chain inputs come from the first member. Later members must read
        # the same set of entities; their column sets are not merged in.
        member_inputs = tuple(needs_entities)
        if chain_inputs is None:
            chain_inputs = member_inputs
        elif {e for e, _, _ in chain_inputs} != {e for e, _, _ in member_inputs}:
            raise DagValidationError(
                "Members of group disagree on input entities"
            )

        members.append(
            CompiledMember(
                fn=fn,
                needs_entities=tuple(
                    (e, c) for e, c, _ in needs_entities
                ),
                needs_providers=tuple(needs_providers),
                produces_columns=produced_cols,
            )
        )

    # A chain that creates its entity from other inputs also owns its
    # primary-key column.
    if produces is not None and not any_member_has_same_entity_input:
        pk_col = _find_pk(produces)
        produces_cols_union.add(pk_col)

    return _ParsedChain(
        raw=raw,
        produces=produces,
        produces_cols=frozenset(produces_cols_union),
        inputs=chain_inputs or (),
        members=members,
    )
dagrun/dag_runner.py ADDED
@@ -0,0 +1,59 @@
1
+ from collections.abc import Callable, Iterable, Mapping
2
+ from datetime import timedelta
3
+ from typing import Any
4
+
5
+ from dagrun.dag import Dag
6
+ from dagrun.events import Observer
7
+ from dagrun.model import _internalize_columns
8
+ from dagrun.store import MemoryStore, Store
9
+ from dagrun.strategy import SingleThread, Strategy
10
+
11
+
12
class DagRunner:
    """Compile a Dag and hand the resulting plan to an execution strategy.

    The runner holds the store, strategy and observers used by every
    :meth:`execute` call, plus registries of dependency providers
    (:meth:`provide`) and external column backends (:meth:`register`).
    When not supplied, the store defaults to MemoryStore and the strategy
    to SingleThread.
    """

    def __init__(
        self,
        store: Store | None = None,
        *,
        strategy: Strategy | None = None,
        observers: list[Observer] | None = None,
    ) -> None:
        if store is None:
            store = MemoryStore()
        if strategy is None:
            strategy = SingleThread()
        self._store: Store = store
        self._strategy: Strategy = strategy
        # Copied so add_observer() never mutates the caller's list.
        self._observers: list[Observer] = [] if observers is None else list(observers)
        self._providers: dict[type, Callable[[], Any]] = {}
        self._external_backends: dict[type, Any] = {}

    def provide(self, kind: type, factory: Callable[[], Any]) -> None:
        """Register *factory* for *kind*, replacing any earlier registration.

        The mapping is passed to the strategy; presumably it supplies member
        parameters annotated with *kind* (confirm against the strategy).
        """
        self._providers[kind] = factory

    def register(self, column_kind: type, backend: Any) -> None:
        """Register *backend* as the external storage for *column_kind*."""
        self._external_backends[column_kind] = backend

    def add_observer(self, observer: Observer) -> None:
        """Append *observer* to the observers passed to later runs."""
        self._observers.append(observer)

    def execute(
        self,
        dag: Dag,
        *,
        labels: set[str] | None = None,
        max_age: timedelta | None = None,
        pools: Mapping[str, int] | None = None,
    ) -> None:
        """Compile *dag*, optionally narrow it to *labels*, and run it."""
        compiled = dag.compile()
        plan = compiled if labels is None else compiled.filter(labels=labels)
        self._strategy.run(
            plan,
            store=self._store,
            external_backends=self._external_backends,
            providers=self._providers,
            observers=self._observers,
            max_age=max_age,
            pools=pools,
        )

    def get_entities(self, entity: type) -> Iterable[Any]:
        """Lazily yield *entity* instances rebuilt from the store's rows."""
        rows = self._store.get_entities(entity)
        for row in rows:
            columns = _internalize_columns(entity, row, self._external_backends)
            yield entity(**columns)
dagrun/events.py ADDED
@@ -0,0 +1,96 @@
1
+ from collections.abc import Mapping
2
+ from dataclasses import dataclass
3
+ from typing import Any, Protocol
4
+
5
+
6
+ @dataclass(frozen=True, kw_only=True)
7
+ class Event:
8
+ timestamp: float
9
+
10
+
11
@dataclass(frozen=True, kw_only=True)
class ChainStarted(Event):
    """Signals that ``chain`` has started."""

    chain: Any
14
+
15
+
16
@dataclass(frozen=True, kw_only=True)
class ChainFinished(Event):
    """Signals that ``chain`` has finished."""

    chain: Any
19
+
20
+
21
@dataclass(frozen=True, kw_only=True)
class MemberStarted(Event):
    """Signals that ``member`` of ``chain`` started on the input keyed ``input_pk``."""

    chain: Any
    member: Any
    input_pk: Any
26
+
27
+
28
@dataclass(frozen=True, kw_only=True)
class MemberFinished(Event):
    """Signals that ``member`` of ``chain`` finished on ``input_pk`` with ``result``."""

    chain: Any
    member: Any
    input_pk: Any
    result: Any
34
+
35
+
36
@dataclass(frozen=True, kw_only=True)
class MemberFailed(Event):
    """Signals that ``member`` of ``chain`` raised ``error`` while handling ``input_pk``."""

    chain: Any
    member: Any
    input_pk: Any
    error: BaseException
42
+
43
+
44
@dataclass(frozen=True, kw_only=True)
class RowWritten(Event):
    """Signals a write of ``columns`` to the ``entity`` row keyed ``pk``."""

    entity: type
    pk: Any
    columns: frozenset[str]
49
+
50
+
51
@dataclass(frozen=True, kw_only=True)
class RunStarted(Event):
    """Signals the start of a run; carries only the timestamp."""
54
+
55
+
56
@dataclass(frozen=True, kw_only=True)
class PoolsRegistered(Event):
    """Sent once when a run begins, exposing the live pool gauges.

    `capacity` holds each pool's token cap. `available` holds the tokens
    still free, and the scheduler keeps updating it as tasks acquire and
    release tokens. Both are the scheduler's own dicts rather than copies,
    so an observer can sample utilization without per-task events. Because
    the set of pool keys is fixed at startup, reading them without the
    scheduler lock is acceptable for display purposes.
    """

    capacity: Mapping[str, int]
    available: Mapping[str, int]
69
+
70
+
71
@dataclass(frozen=True, kw_only=True)
class RunFinished(Event):
    """Signals the end of a run.

    ``error`` defaults to None. Presumably it holds the exception that
    ended the run when the run failed (confirm against the strategy).
    """

    error: BaseException | None = None
74
+
75
+
76
@dataclass(frozen=True, kw_only=True)
class TaskEnqueued(Event):
    """A (chain, input_pk) task was scheduled; adds one to the chain's total."""

    chain: Any
81
+
82
+
83
@dataclass(frozen=True, kw_only=True)
class TaskCompleted(Event):
    """A task is done, either because it ran or because it was skipped as fresh.

    `skipped` is True when max_age suppressed the run: the work was done in
    an earlier run and the stored result is still current. Either way the
    event adds one to the chain's completed count.
    """

    chain: Any
    skipped: bool
93
+
94
+
95
class Observer(Protocol):
    """Structural interface: any object with an ``on_event`` method qualifies."""

    def on_event(self, event: Event) -> None:
        """Receive one published event."""
        ...
dagrun/ext/__init__.py ADDED
File without changes