dagrun 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- dagrun/__init__.py +6 -0
- dagrun/chain.py +71 -0
- dagrun/control.py +9 -0
- dagrun/dag.py +396 -0
- dagrun/dag_runner.py +59 -0
- dagrun/events.py +96 -0
- dagrun/ext/__init__.py +0 -0
- dagrun/ext/blob.py +80 -0
- dagrun/ext/progress.py +277 -0
- dagrun/ext/sqlite.py +326 -0
- dagrun/model.py +272 -0
- dagrun/plan.py +67 -0
- dagrun/py.typed +0 -0
- dagrun/store.py +88 -0
- dagrun/strategy.py +778 -0
- dagrun-0.1.0.dist-info/METADATA +87 -0
- dagrun-0.1.0.dist-info/RECORD +19 -0
- dagrun-0.1.0.dist-info/WHEEL +4 -0
- dagrun-0.1.0.dist-info/licenses/LICENSE +21 -0
dagrun/__init__.py
ADDED
dagrun/chain.py
ADDED
|
@@ -0,0 +1,71 @@
|
|
|
1
|
+
from collections.abc import Callable, Mapping
|
|
2
|
+
from dataclasses import dataclass
|
|
3
|
+
from typing import Any, Protocol
|
|
4
|
+
|
|
5
|
+
|
|
6
|
+
def member(fn: Callable[..., Any]) -> Callable[..., Any]:
    """Mark ``fn`` as a member of a ``@dag.group`` class.

    The function object itself is returned (no wrapper), carrying a
    ``__dagrun_member__`` flag that ``Dag.group`` looks for when collecting
    the class's members.
    """
    setattr(fn, "__dagrun_member__", True)
    return fn
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
class Combine(Protocol):
    """Structural interface for rules that merge one column's member outputs.

    ``reduce`` receives the column name and the list of values produced for
    it and returns the single value to keep. NOTE(review): the list is
    presumably in member order (LastWins/FirstWins rely on that) — confirm
    in the strategy that calls it.
    """

    def reduce(self, column: str, values: list[Any]) -> Any: ...
|
|
13
|
+
|
|
14
|
+
|
|
15
|
+
@dataclass(frozen=True)
class LastWins:
    """Combine rule that keeps the final value in ``values``.

    This is the default rule for ``@dag.fn`` chains and for groups that do
    not specify one. An empty ``values`` list raises ``IndexError``.
    """

    def reduce(self, column: str, values: list[Any]) -> Any:
        final_index = len(values) - 1
        return values[final_index]
|
|
19
|
+
|
|
20
|
+
|
|
21
|
+
@dataclass(frozen=True)
class FirstWins:
    """Combine rule that keeps the leading value in ``values``.

    An empty ``values`` list raises ``IndexError``.
    """

    def reduce(self, column: str, values: list[Any]) -> Any:
        head = values[0]
        return head
|
|
25
|
+
|
|
26
|
+
|
|
27
|
+
@dataclass(frozen=True)
class Error:
    """Combine rule that requires every member to produce the same value.

    Values are compared through ``_hash``: hashable values by equality,
    unhashable ones by a repr-based key. Any disagreement raises
    ``CombineConflict``. Otherwise the shared (first) value is returned.
    """

    def reduce(self, column: str, values: list[Any]) -> Any:
        keys = set(map(_hash, values))
        if len(keys) <= 1:
            return values[0]
        raise CombineConflict(column, values)
|
|
34
|
+
|
|
35
|
+
|
|
36
|
+
class CombineConflict(RuntimeError):
    """Raised by the ``Error`` combine rule when members disagree on a column.

    The offending column name and the full list of values are kept on the
    exception as ``column`` and ``values``.
    """

    def __init__(self, column: str, values: list[Any]) -> None:
        message = f"Conflicting values for column {column!r}: {values!r}"
        super().__init__(message)
        self.column = column
        self.values = values
|
|
41
|
+
|
|
42
|
+
|
|
43
|
+
def _hash(value: Any) -> Any:
|
|
44
|
+
try:
|
|
45
|
+
hash(value)
|
|
46
|
+
except TypeError:
|
|
47
|
+
return repr(value)
|
|
48
|
+
return value
|
|
49
|
+
|
|
50
|
+
|
|
51
|
+
@dataclass(frozen=True)
class PerColumn:
    """Combine rule that dispatches to a different rule per column.

    ``rules`` maps column names to the ``Combine`` used for that column. The
    key ``'*'`` supplies the default for columns that are not listed. When
    neither the column nor ``'*'`` is present, the last value wins.
    """

    rules: tuple[tuple[str, Combine], ...]

    # The union below is built from two real generic aliases. Writing the
    # tuple half as a string literal would make ``|`` raise TypeError when the
    # def is evaluated, because this module has no postponed annotations.
    def __init__(
        self,
        rules: Mapping[str, Combine] | tuple[tuple[str, Combine], ...],
    ) -> None:
        # Normalise to a tuple of pairs so the frozen instance stays hashable.
        if isinstance(rules, Mapping):
            items = tuple(rules.items())
        else:
            items = tuple(rules)
        # Frozen dataclasses block normal assignment, so bypass __setattr__.
        object.__setattr__(self, "rules", items)

    def reduce(self, column: str, values: list[Any]) -> Any:
        """Combine ``values`` with the rule for ``column`` (or ``'*'``)."""
        lookup = dict(self.rules)
        rule = lookup.get(column, lookup.get("*"))
        if rule is None:
            # No specific or wildcard rule: behave like LastWins.
            return values[-1]
        return rule.reduce(column, values)
|
|
70
|
+
|
|
71
|
+
|
dagrun/control.py
ADDED
|
@@ -0,0 +1,9 @@
|
|
|
1
|
+
class Defer(Exception):
    """Signal from a node that the current input should not be processed now.

    When a node raises Defer the runner discards any output produced for that
    input and records no invocation, so the input is evaluated again from
    scratch on the next run rather than being treated as done. Use it for
    decisions that depend on runtime configuration (e.g. a token budget)
    rather than on an intrinsic, persistable property of the input.

    The class adds no state of its own; any constructor arguments end up in
    ``args`` as for any ``Exception``.
    """
|
dagrun/dag.py
ADDED
|
@@ -0,0 +1,396 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
import inspect
|
|
4
|
+
import types
|
|
5
|
+
import typing
|
|
6
|
+
from collections.abc import (
|
|
7
|
+
AsyncIterable,
|
|
8
|
+
AsyncIterator,
|
|
9
|
+
Callable,
|
|
10
|
+
Generator,
|
|
11
|
+
Iterable,
|
|
12
|
+
Iterator,
|
|
13
|
+
Mapping,
|
|
14
|
+
)
|
|
15
|
+
from dataclasses import dataclass, field
|
|
16
|
+
from typing import Any, Union, get_args, get_origin
|
|
17
|
+
|
|
18
|
+
from dagrun.chain import Combine, LastWins, PerColumn
|
|
19
|
+
from dagrun.model import (
|
|
20
|
+
Entity,
|
|
21
|
+
SYSTEM_COLUMNS,
|
|
22
|
+
_entity_specs,
|
|
23
|
+
_EntityView,
|
|
24
|
+
_find_pk,
|
|
25
|
+
_find_unique_fk_to,
|
|
26
|
+
_required_columns,
|
|
27
|
+
)
|
|
28
|
+
from dagrun.plan import CompiledChain, CompiledInput, CompiledMember, Plan
|
|
29
|
+
|
|
30
|
+
|
|
31
|
+
class DagValidationError(Exception):
    """Raised for invalid registrations and for graphs that fail compilation.

    Registration errors come from ``Dag.fn`` / ``Dag.group`` (wrong decorator
    target, group without members). Compilation errors come from
    ``Dag.compile`` and its helpers: missing producers, duplicate column
    owners, disagreeing members, and cycles.
    """

    pass
|
|
33
|
+
|
|
34
|
+
|
|
35
|
+
@dataclass
class _Chain:
    """A chain as registered by ``Dag.fn`` / ``Dag.group``, before parsing."""

    # The single function for @dag.fn, or the @member-tagged functions found
    # in the class body for @dag.group.
    members: list[Callable[..., Any]]
    # Rule reconciling per-member outputs (always LastWins for @dag.fn).
    combine: Combine
    # Function/class name plus any user-supplied labels.
    labels: set[str] = field(default_factory=set)
    # Resource costs; Dag._register always seeds {"concurrency": 1}.
    cost: Mapping[str, int] = field(default_factory=dict)
|
|
41
|
+
|
|
42
|
+
|
|
43
|
+
# Values of typing.get_origin(annotation) that mark an annotation as a
# collection of entities, e.g. list[Child] or Iterable[Child]. Annotations
# whose origin is not listed here (for example collections.abc.Sequence) are
# not treated as collections by _unwrap_input_param / _unwrap_entity.
_ITERABLE_ORIGINS: frozenset[Any] = frozenset(
    {Iterable, Iterator, Generator, AsyncIterable, AsyncIterator, list, tuple, set, frozenset}
)
|
|
46
|
+
|
|
47
|
+
|
|
48
|
+
def _unwrap_input_param(
|
|
49
|
+
ann: Any,
|
|
50
|
+
) -> tuple[type | None, tuple[str, ...], bool]:
|
|
51
|
+
"""Like _unwrap_entity, but additionally reports whether the outermost
|
|
52
|
+
wrapper was a collection (list/Iterable/etc.). A collection annotation on
|
|
53
|
+
an input parameter signals an aggregate chain: dagrun groups child rows
|
|
54
|
+
by their FK to the produced parent entity and calls the fn once per
|
|
55
|
+
parent with all matching children."""
|
|
56
|
+
if ann is inspect.Parameter.empty or ann is inspect.Signature.empty:
|
|
57
|
+
return None, (), False
|
|
58
|
+
origin = get_origin(ann)
|
|
59
|
+
if origin in _ITERABLE_ORIGINS:
|
|
60
|
+
args = get_args(ann)
|
|
61
|
+
if not args:
|
|
62
|
+
return None, (), False
|
|
63
|
+
entity, cols = _unwrap_entity(args[0])
|
|
64
|
+
if entity is None:
|
|
65
|
+
return None, (), False
|
|
66
|
+
return entity, cols, True
|
|
67
|
+
entity, cols = _unwrap_entity(ann)
|
|
68
|
+
return entity, cols, False
|
|
69
|
+
|
|
70
|
+
|
|
71
|
+
def _unwrap_entity(ann: Any) -> tuple[type | None, tuple[str, ...]]:
|
|
72
|
+
"""Peel Iterable/Union/Annotated wrappers and return (entity, columns).
|
|
73
|
+
|
|
74
|
+
Returns (None, ()) for annotations that don't resolve to an Entity subclass.
|
|
75
|
+
For a bare Entity, returns all required column names.
|
|
76
|
+
"""
|
|
77
|
+
if ann is inspect.Parameter.empty or ann is inspect.Signature.empty:
|
|
78
|
+
return None, ()
|
|
79
|
+
|
|
80
|
+
metadata = getattr(ann, "__metadata__", None)
|
|
81
|
+
if metadata is not None:
|
|
82
|
+
for m in metadata:
|
|
83
|
+
if isinstance(m, _EntityView):
|
|
84
|
+
return m.entity, m.columns
|
|
85
|
+
return _unwrap_entity(ann.__origin__)
|
|
86
|
+
|
|
87
|
+
origin = get_origin(ann)
|
|
88
|
+
if origin is Union or origin is types.UnionType:
|
|
89
|
+
non_none = [a for a in get_args(ann) if a is not type(None)]
|
|
90
|
+
if len(non_none) == 1:
|
|
91
|
+
return _unwrap_entity(non_none[0])
|
|
92
|
+
return None, ()
|
|
93
|
+
|
|
94
|
+
if origin in _ITERABLE_ORIGINS:
|
|
95
|
+
args = get_args(ann)
|
|
96
|
+
if args:
|
|
97
|
+
return _unwrap_entity(args[0])
|
|
98
|
+
return None, ()
|
|
99
|
+
|
|
100
|
+
if isinstance(ann, type) and issubclass(ann, Entity):
|
|
101
|
+
return ann, _required_columns(ann)
|
|
102
|
+
|
|
103
|
+
return None, ()
|
|
104
|
+
|
|
105
|
+
|
|
106
|
+
def _detect_cycle(
|
|
107
|
+
edges: dict[CompiledChain, list[CompiledChain]],
|
|
108
|
+
) -> list[CompiledChain] | None:
|
|
109
|
+
color: dict[CompiledChain, int] = {}
|
|
110
|
+
stack: list[CompiledChain] = []
|
|
111
|
+
|
|
112
|
+
def dfs(node: CompiledChain) -> list[CompiledChain] | None:
|
|
113
|
+
color[node] = 1
|
|
114
|
+
stack.append(node)
|
|
115
|
+
for nxt in edges.get(node, ()):
|
|
116
|
+
c = color.get(nxt, 0)
|
|
117
|
+
if c == 1:
|
|
118
|
+
i = stack.index(nxt)
|
|
119
|
+
return stack[i:] + [nxt]
|
|
120
|
+
if c == 0:
|
|
121
|
+
cyc = dfs(nxt)
|
|
122
|
+
if cyc is not None:
|
|
123
|
+
return cyc
|
|
124
|
+
stack.pop()
|
|
125
|
+
color[node] = 2
|
|
126
|
+
return None
|
|
127
|
+
|
|
128
|
+
for n in edges:
|
|
129
|
+
if color.get(n, 0) == 0:
|
|
130
|
+
cyc = dfs(n)
|
|
131
|
+
if cyc is not None:
|
|
132
|
+
return cyc
|
|
133
|
+
return None
|
|
134
|
+
|
|
135
|
+
|
|
136
|
+
class Dag:
    """Registry of chains that compiles into an executable ``Plan``.

    A chain is either a single function (``@dag.fn``) or a class whose
    ``@member``-tagged functions are combined (``@dag.group``). Entity
    inputs and outputs are read from the members' type annotations when
    ``compile`` runs.
    """

    def __init__(self) -> None:
        self._chains: list[_Chain] = []

    def fn(
        self,
        target: Any = None,
        *,
        labels: set[str] | None = None,
        cost: Mapping[str, int] | None = None,
    ) -> Any:
        """Register a single function as a chain.

        Usable bare (``@dag.fn``) or with options (``@dag.fn(labels=...)``).
        The function is returned unchanged. Raises DagValidationError if the
        target is a class.
        """
        return self._register(
            target, labels, combine=LastWins(), is_group=False, cost=cost
        )

    def group(
        self,
        target: Any = None,
        *,
        labels: set[str] | None = None,
        combine: Combine | Mapping[str, Combine] | None = None,
        cost: Mapping[str, int] | None = None,
    ) -> Any:
        """Register a class's ``@member`` functions as one chain.

        ``combine`` may be a Combine rule or a column -> rule mapping (wrapped
        in PerColumn). It defaults to LastWins. The class is returned
        unchanged. Raises DagValidationError if the target is not a class or
        has no members.
        """
        resolved = combine
        if isinstance(resolved, Mapping):
            resolved = PerColumn(rules=dict(resolved))
        # NOTE(review): ``or`` also replaces a falsy Combine object (one that
        # defines __bool__/__len__); ``is None`` would be stricter.
        return self._register(
            target, labels, combine=resolved or LastWins(), is_group=True, cost=cost
        )

    def compile(self) -> Plan:
        """Validate all registered chains and build the execution ``Plan``.

        Raises DagValidationError when an aggregate chain produces no entity or
        has no unique FK to it, a column is produced by more than one chain, a
        required column has no producer, or the dependency graph has a cycle.
        Parse-time errors from ``_parse_chain`` also propagate.
        """
        parsed: list[_ParsedChain] = [_parse_chain(raw) for raw in self._chains]

        # Pass 1: build one CompiledChain per registration and validate
        # aggregate (collection) inputs.
        compiled_chains: list[CompiledChain] = []
        parsed_to_compiled: dict[int, CompiledChain] = {}
        for p in parsed:
            labels = set(p.raw.labels)
            if p.produces is not None:
                # Chains are also selectable by their produced entity's name.
                labels.add(p.produces.__name__.lower())
            for e, _cols, is_coll in p.inputs:
                if not is_coll:
                    continue
                if p.produces is None:
                    raise DagValidationError(
                        f"Aggregate chain reading {e.__name__} must produce "
                        f"an entity (none declared)."
                    )
                # Children are grouped by their FK to the parent, so exactly
                # one such FK must exist.
                try:
                    _find_unique_fk_to(e, p.produces)
                except LookupError as exc:
                    raise DagValidationError(str(exc)) from exc

            compiled = CompiledChain(
                # Chains without an entity output use ``object`` as a placeholder.
                produces=p.produces if p.produces is not None else object,
                inputs=tuple(
                    CompiledInput(entity=e, columns=cols, is_collection=is_coll)
                    for e, cols, is_coll in p.inputs
                ),
                members=tuple(p.members),
                combine=p.raw.combine,
                labels=frozenset(labels),
                cost=dict(p.raw.cost),
            )
            compiled_chains.append(compiled)
            # Keyed by id(): ``parsed`` keeps every _ParsedChain alive.
            parsed_to_compiled[id(p)] = compiled

        # Pass 2: assign each (entity, column) to exactly one producing chain.
        column_owners: dict[tuple[type, str], CompiledChain] = {}
        column_kinds: dict[tuple[type, str], type] = {}

        for p in parsed:
            if p.produces is None:
                continue
            specs = _entity_specs(p.produces)
            compiled = parsed_to_compiled[id(p)]
            for col in p.produces_cols:
                if col in SYSTEM_COLUMNS:
                    continue
                key = (p.produces, col)
                if key in column_owners:
                    raise DagValidationError(
                        f"Column {p.produces.__name__}.{col} produced by multiple chains"
                    )
                column_owners[key] = compiled
                if col in specs:
                    column_kinds[key] = specs[col].kind

        # Pass 3: add an edge owner -> consumer for every column a member
        # reads; a bare entity (no columns) depends on its primary key.
        edges: dict[CompiledChain, list[CompiledChain]] = {c: [] for c in compiled_chains}
        for p in parsed:
            compiled = parsed_to_compiled[id(p)]
            for m in p.members:
                for entity, cols in m.needs_entities:
                    needed_raw = list(cols) if cols else [_find_pk(entity)]
                    needed = [c for c in needed_raw if c not in SYSTEM_COLUMNS] or [
                        _find_pk(entity)
                    ]
                    for col in needed:
                        owner = column_owners.get((entity, col))
                        if owner is None:
                            raise DagValidationError(
                                f"No chain produces {entity.__name__}.{col} "
                                f"(required by {_chain_label(p)})"
                            )
                        # Reading a column the chain itself produces adds no edge.
                        if owner is compiled:
                            continue
                        if compiled not in edges[owner]:
                            edges[owner].append(compiled)

        cycle = _detect_cycle(edges)
        if cycle is not None:
            names = " -> ".join(c.produces.__name__ for c in cycle)
            raise DagValidationError(f"Cycle detected: {names}")

        return Plan(
            chains=tuple(compiled_chains),
            column_owners=column_owners,
            column_kinds=column_kinds,
            edges={k: tuple(v) for k, v in edges.items()},
        )

    def _register(
        self,
        target: Any,
        labels: set[str] | None,
        combine: Combine,
        is_group: bool,
        cost: Mapping[str, int] | None = None,
    ) -> Any:
        """Shared implementation of ``fn`` and ``group``.

        Returns the decorator when ``target`` is None (decorator-with-options
        form), otherwise registers ``target`` immediately and returns it.
        """
        def register(t: Any) -> Any:
            if is_group:
                if not isinstance(t, type):
                    raise DagValidationError(
                        "@dag.group expects a class with @member functions; "
                        "use @dag.fn for a solo function"
                    )
                # Only attributes defined directly on the class are scanned.
                members = [
                    v
                    for v in vars(t).values()
                    if callable(v) and getattr(v, "__dagrun_member__", False)
                ]
                if not members:
                    raise DagValidationError(
                        f"@dag.group {t.__name__} has no @member functions"
                    )
            else:
                if isinstance(t, type):
                    raise DagValidationError(
                        "@dag.fn expects a function; use @dag.group for a class with members"
                    )
                members = [t]
            auto_labels = {t.__name__}
            # Every chain costs one "concurrency" token unless overridden.
            merged_cost: dict[str, int] = {"concurrency": 1}
            if cost is not None:
                merged_cost.update(cost)
            self._chains.append(
                _Chain(
                    members=members,
                    combine=combine,
                    labels=auto_labels | (labels or set()),
                    cost=merged_cost,
                )
            )
            return t

        if target is None:
            return register
        return register(target)
|
|
302
|
+
|
|
303
|
+
|
|
304
|
+
@dataclass
class _ParsedChain:
    """Result of ``_parse_chain``: a registration plus its annotated contract."""

    raw: _Chain  # the registration this was parsed from
    produces: type | None  # entity returned by the members, or None
    # Union of all members' returned columns, plus the primary key for chains
    # that create rows of ``produces``.
    produces_cols: frozenset[str]
    # (entity, columns, is_collection) per input, taken from the first member.
    inputs: tuple[tuple[type, frozenset[str], bool], ...]
    members: list[CompiledMember]
|
|
311
|
+
|
|
312
|
+
|
|
313
|
+
def _chain_label(p: _ParsedChain) -> str:
|
|
314
|
+
fn = p.raw.members[0] if p.raw.members else None
|
|
315
|
+
return getattr(fn, "__qualname__", repr(p.raw))
|
|
316
|
+
|
|
317
|
+
|
|
318
|
+
def _resolve_hints(fn: Callable[..., Any]) -> tuple[dict[str, Any], Exception | None]:
    """Evaluate ``fn``'s type hints, returning ``({}, error)`` if that fails.

    The caller then falls back to the raw ``inspect.signature`` annotations,
    which are still usable when they are real objects rather than strings.
    """
    try:
        return typing.get_type_hints(fn, include_extras=True), None
    except Exception as exc:  # NameError, TypeError, ...: use raw annotations
        return {}, exc


def _reject_unresolved_annotation(
    fn: Callable[..., Any], where: str, ann: Any, cause: Exception | None
) -> None:
    """Raise DagValidationError if ``ann`` is a string that was never evaluated.

    A string at this point means ``get_type_hints`` failed and the raw
    annotation is an unresolved forward reference (e.g. under
    ``from __future__ import annotations``). Ignoring it would silently drop
    an input or the produced entity from the chain.
    """
    if isinstance(ann, str):
        name = getattr(fn, "__qualname__", repr(fn))
        raise DagValidationError(
            f"Cannot resolve annotation {ann!r} for {where} of {name}; "
            f"make sure the name is defined or imported at module level "
            f"in the function's module"
        ) from cause


def _parse_chain(raw: _Chain) -> _ParsedChain:
    """Inspect the members of ``raw`` and derive the chain's data contract.

    For each member this reads the signature and resolved type hints:

    * parameters annotated with an Entity, an ``Annotated`` entity view, or a
      collection of either become entity inputs (``is_collection`` marks the
      collection form used by aggregate chains);
    * other parameters annotated with a plain class become provider
      dependencies (``needs_providers``);
    * the return annotation names the produced entity and columns.

    All members must agree on the produced entity and on the set of input
    entities. When no member reads the produced entity or aggregates a
    collection into it, the chain creates those rows and also claims the
    entity's primary-key column.

    Raises:
        DagValidationError: if members disagree on the produced entity or on
            their input entities, or if an annotation could not be resolved
            and is still a string.
    """
    members: list[CompiledMember] = []
    produces: type | None = None
    produces_cols_union: set[str] = set()
    chain_inputs: tuple[tuple[type, frozenset[str], bool], ...] | None = None
    # Set when a member reads the entity it returns or aggregates a
    # collection: such chains update existing rows instead of creating them.
    any_member_has_same_entity_input = False

    for fn in raw.members:
        hints, hint_error = _resolve_hints(fn)
        sig = inspect.signature(fn)

        needs_entities: list[tuple[type, frozenset[str], bool]] = []
        needs_providers: list[type] = []
        for pname, param in sig.parameters.items():
            ann = hints.get(pname, param.annotation)
            _reject_unresolved_annotation(fn, f"parameter {pname!r}", ann, hint_error)
            entity, cols, is_collection = _unwrap_input_param(ann)
            if entity is not None:
                needs_entities.append((entity, frozenset(cols), is_collection))
            elif ann is not inspect.Parameter.empty and isinstance(ann, type):
                # The explicit ``empty`` check matters: inspect._empty is a class.
                needs_providers.append(ann)

        ret_ann = hints.get("return", sig.return_annotation)
        _reject_unresolved_annotation(fn, "the return value", ret_ann, hint_error)
        ret_entity, ret_cols = _unwrap_entity(ret_ann)

        if ret_entity is not None:
            if produces is None:
                produces = ret_entity
            elif produces is not ret_entity:
                raise DagValidationError(
                    f"Members of chain disagree on produced entity: "
                    f"{produces.__name__} vs {ret_entity.__name__}"
                )

        if ret_entity is not None and any(
            e is ret_entity or is_coll for e, _, is_coll in needs_entities
        ):
            any_member_has_same_entity_input = True

        produced_cols = frozenset(ret_cols)
        produces_cols_union |= produced_cols

        member_inputs = tuple(needs_entities)
        if chain_inputs is None:
            chain_inputs = member_inputs
        elif {e for e, _, _ in chain_inputs} != {e for e, _, _ in member_inputs}:
            raise DagValidationError(
                "Members of group disagree on input entities"
            )

        members.append(
            CompiledMember(
                fn=fn,
                needs_entities=tuple(
                    (e, c) for e, c, _ in needs_entities
                ),
                needs_providers=tuple(needs_providers),
                produces_columns=produced_cols,
            )
        )

    if produces is not None and not any_member_has_same_entity_input:
        # Row-creating chains own the primary key of the produced entity.
        produces_cols_union.add(_find_pk(produces))

    return _ParsedChain(
        raw=raw,
        produces=produces,
        produces_cols=frozenset(produces_cols_union),
        inputs=chain_inputs or (),
        members=members,
    )
|
dagrun/dag_runner.py
ADDED
|
@@ -0,0 +1,59 @@
|
|
|
1
|
+
from collections.abc import Callable, Iterable, Mapping
|
|
2
|
+
from datetime import timedelta
|
|
3
|
+
from typing import Any
|
|
4
|
+
|
|
5
|
+
from dagrun.dag import Dag
|
|
6
|
+
from dagrun.events import Observer
|
|
7
|
+
from dagrun.model import _internalize_columns
|
|
8
|
+
from dagrun.store import MemoryStore, Store
|
|
9
|
+
from dagrun.strategy import SingleThread, Strategy
|
|
10
|
+
|
|
11
|
+
|
|
12
|
+
class DagRunner:
    """Compile ``Dag`` objects and hand the resulting plan to a ``Strategy``.

    The runner holds the store results are written to, the observers that
    receive events, provider factories, and backends for external column
    kinds. Defaults are a fresh ``MemoryStore`` and ``SingleThread``.
    """

    def __init__(
        self,
        store: Store | None = None,
        *,
        strategy: Strategy | None = None,
        observers: list[Observer] | None = None,
    ) -> None:
        if store is None:
            store = MemoryStore()
        if strategy is None:
            strategy = SingleThread()
        self._store: Store = store
        self._strategy: Strategy = strategy
        # Copy so add_observer never mutates the caller's list.
        self._observers: list[Observer] = [] if observers is None else list(observers)
        self._providers: dict[type, Callable[[], Any]] = {}
        self._external_backends: dict[type, Any] = {}

    def provide(self, kind: type, factory: Callable[[], Any]) -> None:
        """Register a zero-argument ``factory`` for ``kind``.

        The mapping is passed to the strategy as ``providers``.
        """
        self._providers.update({kind: factory})

    def register(self, column_kind: type, backend: Any) -> None:
        """Register ``backend`` for columns of ``column_kind``.

        Backends are passed to the strategy and used by ``get_entities``.
        """
        self._external_backends.update({column_kind: backend})

    def add_observer(self, observer: Observer) -> None:
        """Append ``observer`` to the observers passed to future runs."""
        self._observers.append(observer)

    def execute(
        self,
        dag: Dag,
        *,
        labels: set[str] | None = None,
        max_age: timedelta | None = None,
        pools: Mapping[str, int] | None = None,
    ) -> None:
        """Compile ``dag`` and run it.

        When ``labels`` is given, the plan is first narrowed with
        ``Plan.filter(labels=...)``.
        """
        compiled = dag.compile()
        selected = compiled if labels is None else compiled.filter(labels=labels)
        self._strategy.run(
            selected,
            store=self._store,
            external_backends=self._external_backends,
            providers=self._providers,
            observers=self._observers,
            max_age=max_age,
            pools=pools,
        )

    def get_entities(self, entity: type) -> Iterable[Any]:
        """Lazily yield ``entity`` instances rebuilt from the store's rows.

        Each row's columns go through ``_internalize_columns`` with the
        registered backends before the entity is constructed.
        """
        for stored_row in self._store.get_entities(entity):
            yield entity(
                **_internalize_columns(entity, stored_row, self._external_backends)
            )
|
dagrun/events.py
ADDED
|
@@ -0,0 +1,96 @@
|
|
|
1
|
+
from collections.abc import Mapping
|
|
2
|
+
from dataclasses import dataclass
|
|
3
|
+
from typing import Any, Protocol
|
|
4
|
+
|
|
5
|
+
|
|
6
|
+
@dataclass(frozen=True, kw_only=True)
class Event:
    """Base class for everything delivered to ``Observer.on_event``.

    Events are frozen, keyword-only dataclasses.
    """

    # NOTE(review): presumably seconds from time.time() or a monotonic clock;
    # confirm in the strategy that creates events.
    timestamp: float
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
@dataclass(frozen=True, kw_only=True)
class ChainStarted(Event):
    """Event for the start of ``chain`` (emitted by code outside this module)."""

    chain: Any  # NOTE(review): presumably a CompiledChain; confirm in strategy.
|
|
14
|
+
|
|
15
|
+
|
|
16
|
+
@dataclass(frozen=True, kw_only=True)
class ChainFinished(Event):
    """Event for the end of ``chain`` (emitted by code outside this module)."""

    chain: Any  # NOTE(review): presumably a CompiledChain; confirm in strategy.
|
|
19
|
+
|
|
20
|
+
|
|
21
|
+
@dataclass(frozen=True, kw_only=True)
class MemberStarted(Event):
    """Event for one member of ``chain`` starting on the input ``input_pk``."""

    chain: Any
    member: Any  # NOTE(review): presumably a CompiledMember; confirm.
    input_pk: Any  # NOTE(review): presumably the input row's primary key.
|
|
26
|
+
|
|
27
|
+
|
|
28
|
+
@dataclass(frozen=True, kw_only=True)
class MemberFinished(Event):
    """Event for one member of ``chain`` finishing on ``input_pk``."""

    chain: Any
    member: Any
    input_pk: Any
    result: Any  # the member's return value, as handed over by the emitter
|
|
34
|
+
|
|
35
|
+
|
|
36
|
+
@dataclass(frozen=True, kw_only=True)
class MemberFailed(Event):
    """Event for one member of ``chain`` failing on ``input_pk``."""

    chain: Any
    member: Any
    input_pk: Any
    error: BaseException  # the exception raised by the member
|
|
42
|
+
|
|
43
|
+
|
|
44
|
+
@dataclass(frozen=True, kw_only=True)
class RowWritten(Event):
    """Event for a row of ``entity`` (keyed by ``pk``) being written."""

    entity: type
    pk: Any
    columns: frozenset[str]  # names of the columns included in the write
|
|
49
|
+
|
|
50
|
+
|
|
51
|
+
@dataclass(frozen=True, kw_only=True)
class RunStarted(Event):
    """Event marking the beginning of a run; carries only ``timestamp``."""

    pass
|
|
54
|
+
|
|
55
|
+
|
|
56
|
+
@dataclass(frozen=True, kw_only=True)
class PoolsRegistered(Event):
    """Published once at the start of a run with the live pool gauges.

    `capacity` is the per-pool token cap; `available` is the live remaining-token
    mapping the scheduler mutates as tasks acquire and release tokens. Both are
    references to the scheduler's own dicts so an observer can poll utilization
    without per-task events. Pool keys are fixed at startup, so reading these
    without the scheduler lock is safe for display purposes.
    """

    # Observers should treat both mappings as read-only views.
    capacity: Mapping[str, int]
    available: Mapping[str, int]
|
|
69
|
+
|
|
70
|
+
|
|
71
|
+
@dataclass(frozen=True, kw_only=True)
class RunFinished(Event):
    """Event marking the end of a run."""

    # Defaults to None; presumably set to the exception that ended the run
    # when it failed (confirm in the strategy).
    error: BaseException | None = None
|
|
74
|
+
|
|
75
|
+
|
|
76
|
+
@dataclass(frozen=True, kw_only=True)
class TaskEnqueued(Event):
    """A (chain, input_pk) task was scheduled. Grows a chain's denominator."""

    chain: Any  # the chain the task belongs to
|
|
81
|
+
|
|
82
|
+
|
|
83
|
+
@dataclass(frozen=True, kw_only=True)
class TaskCompleted(Event):
    """A task finished, either by running or by being skipped as fresh.

    `skipped` is True when max_age suppressed the run (the work was done in a
    prior run and is current in the store). Both cases grow a chain's numerator.
    """

    chain: Any  # the chain the task belongs to
    skipped: bool
|
|
93
|
+
|
|
94
|
+
|
|
95
|
+
class Observer(Protocol):
    """Structural interface for event consumers.

    Any object with an ``on_event`` method qualifies. Observers are given to
    ``DagRunner`` (``observers=`` or ``add_observer``) and passed on to the
    strategy for each run.
    """

    def on_event(self, event: Event) -> None: ...
|
dagrun/ext/__init__.py
ADDED
|
File without changes
|