py2dag 0.1.15__tar.gz → 0.2.1__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {py2dag-0.1.15 → py2dag-0.2.1}/PKG-INFO +1 -1
- {py2dag-0.1.15 → py2dag-0.2.1}/py2dag/export_dagre.py +15 -2
- py2dag-0.2.1/py2dag/parser.py +500 -0
- {py2dag-0.1.15 → py2dag-0.2.1}/pyproject.toml +1 -1
- py2dag-0.1.15/py2dag/parser.py +0 -335
- {py2dag-0.1.15 → py2dag-0.2.1}/LICENSE +0 -0
- {py2dag-0.1.15 → py2dag-0.2.1}/README.md +0 -0
- {py2dag-0.1.15 → py2dag-0.2.1}/py2dag/__init__.py +0 -0
- {py2dag-0.1.15 → py2dag-0.2.1}/py2dag/cli.py +0 -0
- {py2dag-0.1.15 → py2dag-0.2.1}/py2dag/export_svg.py +0 -0
- {py2dag-0.1.15 → py2dag-0.2.1}/py2dag/pseudo.py +0 -0
@@ -48,9 +48,22 @@ HTML_TEMPLATE = """<!doctype html>
|
|
48
48
|
// Ensure edges have an object for labels/attrs to avoid TypeErrors
|
49
49
|
g.setDefaultEdgeLabel(() => ({}));
|
50
50
|
|
51
|
-
// Add op nodes
|
51
|
+
// Add op nodes with basic styling for control nodes
|
52
52
|
(plan.ops || []).forEach(op => {
|
53
|
-
|
53
|
+
let label = op.op;
|
54
|
+
let klass = 'op';
|
55
|
+
if (op.op === 'COND.eval') {
|
56
|
+
const kind = (op.args && op.args.kind) || 'if';
|
57
|
+
label = (kind.toUpperCase()) + ' ' + (op.args && op.args.expr ? op.args.expr : '');
|
58
|
+
klass = 'note';
|
59
|
+
} else if (op.op === 'ITER.eval') {
|
60
|
+
label = 'FOR ' + (op.args && op.args.expr ? op.args.expr : '');
|
61
|
+
klass = 'note';
|
62
|
+
} else if (op.op === 'PHI') {
|
63
|
+
label = 'PHI' + (op.args && op.args.var ? ` (${op.args.var})` : '');
|
64
|
+
klass = 'note';
|
65
|
+
}
|
66
|
+
g.setNode(op.id, { label, class: klass, padding: 8 });
|
54
67
|
});
|
55
68
|
|
56
69
|
// Add output nodes and edges from source to output
|
@@ -0,0 +1,500 @@
|
|
1
|
+
import ast
|
2
|
+
import json
|
3
|
+
import re
|
4
|
+
from typing import Any, Dict, List, Optional, Tuple, Set
|
5
|
+
|
6
|
+
# Pattern for acceptable DSL variable names: one lowercase ASCII letter or
# underscore, followed by at most 63 lowercase letters, digits or underscores.
VALID_NAME_RE = re.compile(r'^[a-z_][a-z0-9_]{0,63}$')
|
7
|
+
|
8
|
+
|
9
|
+
class DSLParseError(Exception):
    """Signal that the parsed source breaks a rule of the mini-DSL.

    Carries no state of its own; the human-readable reason is the exception
    message passed by the raising code.
    """
|
11
|
+
|
12
|
+
|
13
|
+
def _literal(node: ast.AST) -> Any:
    """Return the plain Python value described by a literal AST node.

    Supported nodes:

    * ``ast.Constant`` -- its value is returned unchanged.
    * ``+``/``-`` applied directly to an int or float constant (for example
      ``-1``), which the Python parser represents as ``ast.UnaryOp`` rather
      than as a constant.
    * ``ast.List`` / ``ast.Tuple`` -- converted recursively; both come back
      as a ``list``.
    * ``ast.Dict`` -- keys and values converted recursively.

    Raises:
        DSLParseError: for any other node (including ``**`` unpacking inside
            a dict literal, whose key is ``None`` in the AST) and for dict
            keys that turn out to be unhashable.
    """
    if isinstance(node, ast.Constant):
        return node.value

    # Signed numbers such as ``-1`` are UnaryOp(USub, Constant(1)) in the AST.
    if isinstance(node, ast.UnaryOp) and isinstance(node.op, (ast.USub, ast.UAdd)):
        operand = node.operand
        if (
            isinstance(operand, ast.Constant)
            and isinstance(operand.value, (int, float))
            and not isinstance(operand.value, bool)
        ):
            if isinstance(node.op, ast.USub):
                return -operand.value
            return operand.value

    if isinstance(node, (ast.List, ast.Tuple)):
        return [_literal(elt) for elt in node.elts]

    if isinstance(node, ast.Dict):
        result: Dict[Any, Any] = {}
        for key_node, value_node in zip(node.keys, node.values):
            key = _literal(key_node)
            value = _literal(value_node)
            try:
                result[key] = value
            except TypeError:
                # e.g. a list used as a key: report it as a DSL error instead
                # of leaking a TypeError to the caller.
                raise DSLParseError(
                    "Keyword argument values must be JSON-serialisable literals"
                ) from None
        return result

    raise DSLParseError("Keyword argument values must be JSON-serialisable literals")
|
22
|
+
|
23
|
+
|
24
|
+
def _get_call_name(func: ast.AST) -> str:
    """Build the dotted operation name for the callee of a call expression.

    A bare name yields its identifier; an attribute chain yields the
    attribute names joined with ``.``, prefixed by the base identifier when
    the chain starts at a plain name.  When the chain starts at anything
    else (for example another call), only the attribute names are returned.

    Raises:
        DSLParseError: if ``func`` is neither a name nor an attribute access.
    """
    if not isinstance(func, (ast.Name, ast.Attribute)):
        raise DSLParseError("Only simple or attribute names are allowed for operations")

    segments: List[str] = []
    cursor = func
    # Walk from the outermost attribute inwards, prepending as we go so the
    # segments end up in source order.
    while isinstance(cursor, ast.Attribute):
        segments.insert(0, cursor.attr)
        cursor = cursor.value
    if isinstance(cursor, ast.Name):
        segments.insert(0, cursor.id)
    return ".".join(segments)
|
36
|
+
|
37
|
+
|
38
|
+
def parse(source: str, function_name: Optional[str] = None) -> Dict[str, Any]:
    """Parse mini-DSL source text into an SSA-style plan dictionary.

    Args:
        source: Python source containing at least one top-level (optionally
            ``async``) function written in the mini-DSL.  At most 20,000
            characters are accepted.
        function_name: Name of the top-level function to translate.  When
            ``None``, top-level functions are tried in source order and the
            first one that parses without a ``DSLParseError`` is used.

    Returns:
        A dict with the keys ``version`` (always ``2``), ``function``,
        ``ops`` and ``outputs``; a ``settings`` key is added when a
        ``settings(...)`` call supplied values.  Every op is a dict with the
        keys ``id``, ``op``, ``deps`` and ``args``.

    Raises:
        DSLParseError: if the source is too large, no suitable function is
            found, or the function body violates the DSL rules.
        SyntaxError: propagated unchanged from ``ast.parse``.
    """
    if len(source) > 20_000:
        raise DSLParseError("Source too large")
    module = ast.parse(source)

    def _parse_fn(fn: ast.AST) -> Dict[str, Any]:
        """Translate a single function definition into a plan dict."""
        ops: List[Dict[str, Any]] = []
        outputs: List[Dict[str, str]] = []
        settings: Dict[str, Any] = {}
        returned_var: Optional[str] = None

        # Enforce no-args top-level function signature
        fargs = getattr(fn, "args", None)
        if fargs is not None and (
            getattr(fargs, "posonlyargs", [])
            or fargs.args
            or fargs.vararg
            or fargs.kwonlyargs
            or fargs.kwarg
        ):
            raise DSLParseError("Top-level function must not accept parameters")

        # SSA state.  ``versions`` counts definitions per base name, ``latest``
        # maps a *user* variable to its current SSA id, and ``context_suffix``
        # tags ids created inside a branch or loop body.
        versions: Dict[str, int] = {}
        latest: Dict[str, str] = {}
        context_suffix: str = ""
        ctx_counts: Dict[str, int] = {"if": 0, "loop": 0, "while": 0}

        def _ssa_new(name: str, track: bool = True) -> str:
            """Allocate a fresh SSA id for ``name``.

            With ``track=False`` the id is not recorded in ``latest``; this is
            used for synthetic nodes (conditions, iterators, literal returns)
            so they never shadow user variables or trigger PHI merges.
            """
            if not VALID_NAME_RE.match(name):
                raise DSLParseError(f"Invalid variable name: {name}")
            versions[name] = versions.get(name, 0) + 1
            base = f"{name}_{versions[name]}"
            ssa = f"{base}@{context_suffix}" if context_suffix else base
            if track:
                latest[name] = ssa
            return ssa

        def _ssa_get(name: str) -> str:
            """Return the current SSA id of ``name`` or raise if undefined."""
            if name not in latest:
                raise DSLParseError(f"Undefined dependency: {name}")
            return latest[name]

        def _emit(name: str, op: str, deps: List[str], args: Dict[str, Any],
                  track: bool = True) -> str:
            """Append an op under a new SSA id for ``name`` and return the id."""
            ssa = _ssa_new(name, track=track)
            ops.append({"id": ssa, "op": op, "deps": deps, "args": args})
            return ssa

        def _link_first(op_index: int, dep_id: str) -> None:
            """Make the op at ``op_index`` depend on ``dep_id`` (at most once)."""
            current = ops[op_index].get("deps", [])
            if dep_id not in current:
                ops[op_index]["deps"] = [*current, dep_id]

        def _collect_name_loads(node: ast.AST) -> List[str]:
            """Return the distinct names read anywhere under ``node``, in walk order."""
            names: List[str] = []
            for n in ast.walk(node):
                if isinstance(n, ast.Name) and isinstance(n.ctx, ast.Load) and n.id not in names:
                    names.append(n.id)
            return names

        def _collect_value_deps(node: ast.AST) -> List[str]:
            """Collect variable name dependencies from an expression, excluding callee names.

            For example, for range(n) -> ['n'] (not 'range'). For cond(a) -> ['a'] (not 'cond').
            For obj.attr -> ['obj'].  Any name that occurs inside the ``func``
            part of some call is excluded everywhere in the expression.
            """
            callees: Set[str] = set()
            for n in ast.walk(node):
                if isinstance(n, ast.Call):
                    callees.update(
                        sub.id for sub in ast.walk(n.func) if isinstance(sub, ast.Name)
                    )
            return [name for name in _collect_name_loads(node) if name not in callees]

        def _stringify(node: ast.AST) -> str:
            """Render ``node`` as source text, falling back to its class name."""
            try:
                return ast.unparse(node)  # type: ignore[attr-defined]
            except Exception:
                # Best effort only: the text is used as a display label.
                return node.__class__.__name__

        def _ids_for_names(elts: List[ast.AST], message: str) -> List[str]:
            """Resolve a sequence of Name nodes to SSA ids; raise ``message`` otherwise."""
            ids: List[str] = []
            for elt in elts:
                if not isinstance(elt, ast.Name):
                    raise DSLParseError(message)
                ids.append(_ssa_get(elt.id))
            return ids

        def _expand_star_name(ssa_var: str) -> List[str]:
            """Expand ``*var`` into the packed element ids when ``var`` is a PACK op."""
            for prev in reversed(ops):
                if prev.get("id") == ssa_var and prev.get("op") in {"PACK.list", "PACK.tuple"}:
                    return list(prev.get("deps", []))
            return [ssa_var]

        def _emit_assign_from_call(var_name: str, call: ast.Call) -> str:
            """Emit an op for ``var = op_name(...)``.

            Variable-name arguments become dependencies; literal keyword
            arguments become entries in ``args``.
            """
            op_name = _get_call_name(call.func)
            deps: List[str] = []

            for arg in call.args:
                if isinstance(arg, ast.Starred):
                    star_val = arg.value
                    if isinstance(star_val, ast.Name):
                        deps.extend(_expand_star_name(_ssa_get(star_val.id)))
                    elif isinstance(star_val, (ast.List, ast.Tuple)):
                        deps.extend(_ids_for_names(
                            star_val.elts, "Starred list/tuple elements must be names"))
                    else:
                        raise DSLParseError("*args must be a name or list/tuple of names")
                elif isinstance(arg, ast.Name):
                    deps.append(_ssa_get(arg.id))
                elif isinstance(arg, (ast.List, ast.Tuple)):
                    deps.extend(_ids_for_names(
                        arg.elts, "List/Tuple positional args must be variable names"))
                else:
                    raise DSLParseError(
                        "Positional args must be variable names or lists/tuples of names")

            kwargs: Dict[str, Any] = {}
            for kw in call.keywords:
                value = kw.value
                if kw.arg is None:
                    # ``**something``: merge a dict literal, or depend on a variable.
                    if isinstance(value, ast.Dict):
                        for key, val in _literal(value).items():
                            kwargs[str(key)] = val
                    elif isinstance(value, ast.Name):
                        deps.append(_ssa_get(value.id))
                    else:
                        raise DSLParseError("**kwargs must be a dict literal or a variable name")
                elif isinstance(value, ast.Name):
                    deps.append(_ssa_get(value.id))
                else:
                    kwargs[kw.arg] = _literal(value)

            # Allocate the id last so ``x = f(x)`` depends on the previous ``x``.
            return _emit(var_name, op_name, deps, kwargs)

        def _emit_assign_from_fstring(var_name: str, fstr: ast.JoinedStr) -> str:
            """Emit a ``TEXT.format`` op for an f-string made of text and bare names."""
            deps: List[str] = []
            parts: List[str] = []
            for item in fstr.values:
                if isinstance(item, ast.Constant) and isinstance(item.value, str):
                    parts.append(item.value)
                elif isinstance(item, ast.FormattedValue) and isinstance(item.value, ast.Name):
                    deps.append(_ssa_get(item.value.id))
                    # Placeholders are positional indexes into ``deps``.
                    parts.append("{" + str(len(deps) - 1) + "}")
                else:
                    raise DSLParseError("f-strings may only contain variable names")
            return _emit(var_name, "TEXT.format", deps, {"template": "".join(parts)})

        def _emit_assign_from_literal_or_pack(var_name: str, value: ast.AST) -> str:
            """Emit ``CONST.value`` for a literal, or ``PACK.*`` for a list/tuple of names."""
            try:
                lit = _literal(value)
            except DSLParseError:
                if not isinstance(value, (ast.List, ast.Tuple)):
                    raise
                deps = _ids_for_names(
                    value.elts, "Only names allowed in non-literal list/tuple assignment")
                kind = "list" if isinstance(value, ast.List) else "tuple"
                return _emit(var_name, f"PACK.{kind}", deps, {})
            # Outside the ``try`` so an invalid target name is reported as such.
            return _emit(var_name, "CONST.value", [], {"value": lit})

        def _emit_assign_from_comp(var_name: str, node: ast.AST) -> str:
            """Emit a ``COMP.*`` op depending on every already-defined name it reads."""
            if isinstance(node, ast.ListComp):
                kind = "listcomp"
            elif isinstance(node, ast.SetComp):
                kind = "setcomp"
            elif isinstance(node, ast.DictComp):
                kind = "dictcomp"
            else:
                kind = "genexpr"
            # Names unknown to the plan (builtins, comprehension targets) are ignored.
            deps = [latest[n] for n in _collect_name_loads(node) if n in latest]
            return _emit(var_name, f"COMP.{kind}", deps, {})

        def _emit_cond(node: ast.AST, kind: str = "if") -> str:
            """Emit a ``COND.eval`` op for a branch or ``while`` test expression."""
            deps = [_ssa_get(n) for n in _collect_value_deps(node)]
            args = {"expr": _stringify(node), "kind": kind}
            return _emit("cond", "COND.eval", deps, args, track=False)

        def _emit_iter(node: ast.AST) -> str:
            """Emit an ``ITER.eval`` op for the iterable of a ``for`` loop."""
            deps = [_ssa_get(n) for n in _collect_value_deps(node)]
            args = {"expr": _stringify(node), "kind": "for"}
            return _emit("iter", "ITER.eval", deps, args, track=False)

        def _run_block(body: List[ast.stmt], suffix: str,
                       bindings: Optional[Dict[str, str]] = None) -> Dict[str, str]:
            """Parse ``body`` on private copies of the SSA state.

            ``suffix`` becomes the context suffix for ids created inside the
            block and ``bindings`` pre-seeds extra names (loop targets).  The
            enclosing state is restored afterwards; the block's final
            ``latest`` mapping is returned for PHI construction.
            """
            nonlocal versions, latest, context_suffix
            outer_state = (versions, latest, context_suffix)
            versions, latest, context_suffix = dict(versions), dict(latest), suffix
            if bindings:
                latest.update(bindings)
            try:
                for inner in body:
                    _parse_stmt(inner)
                return latest
            finally:
                versions, latest, context_suffix = outer_state

        def _parse_assign(stmt: ast.Assign) -> str:
            """Handle ``name = <call | f-string | literal | pack | comprehension>``."""
            if len(stmt.targets) != 1 or not isinstance(stmt.targets[0], ast.Name):
                raise DSLParseError("Assignment targets must be simple names")
            var_name = stmt.targets[0].id
            value = stmt.value
            if isinstance(value, ast.Await):
                value = value.value
            if isinstance(value, ast.Call):
                return _emit_assign_from_call(var_name, value)
            if isinstance(value, ast.JoinedStr):
                return _emit_assign_from_fstring(var_name, value)
            if isinstance(value, (ast.Constant, ast.List, ast.Tuple, ast.Dict)):
                return _emit_assign_from_literal_or_pack(var_name, value)
            if isinstance(value, (ast.ListComp, ast.SetComp, ast.DictComp, ast.GeneratorExp)):
                return _emit_assign_from_comp(var_name, value)
            raise DSLParseError("Right hand side must be a call or f-string")

        def _parse_expr(stmt: ast.Expr) -> None:
            """Handle the two allowed expression statements: settings() and output()."""
            call = stmt.value
            if isinstance(call, ast.Await):
                call = call.value
            if not isinstance(call, ast.Call):
                raise DSLParseError("Only call expressions allowed at top level")
            name = _get_call_name(call.func)
            if name == "settings":
                for kw in call.keywords:
                    if kw.arg is None:
                        raise DSLParseError("settings does not accept **kwargs")
                    settings[kw.arg] = _literal(kw.value)
                if call.args:
                    raise DSLParseError("settings only accepts keyword literals")
            elif name == "output":
                if len(call.args) != 1 or not isinstance(call.args[0], ast.Name):
                    raise DSLParseError("output requires a single variable name argument")
                ssa_from = _ssa_get(call.args[0].id)
                filename = None
                for kw in call.keywords:
                    if kw.arg in {"as", "as_"}:
                        filename = _literal(kw.value)
                    else:
                        raise DSLParseError("output only accepts 'as' keyword")
                if not isinstance(filename, str):
                    raise DSLParseError("output requires as=\"filename\"")
                outputs.append({"from": ssa_from, "as": filename})
            else:
                raise DSLParseError("Only settings() and output() calls allowed as expressions")

        def _parse_return(stmt: ast.Return) -> None:
            """Record the returned variable, or emit a constant for a returned literal."""
            nonlocal returned_var
            if isinstance(stmt.value, ast.Name):
                returned_var = _ssa_get(stmt.value.id)
            elif isinstance(stmt.value, (ast.Constant, ast.List, ast.Tuple, ast.Dict)):
                lit = _literal(stmt.value)
                returned_var = _emit(
                    "return_value", "CONST.value", [], {"value": lit}, track=False)
            else:
                raise DSLParseError("return must return a variable name or literal")

        def _parse_if(stmt: ast.If) -> None:
            """Emit the condition, both branches and PHI merges for an ``if``."""
            cond_id = _emit_cond(stmt.test, kind="if")
            pre_latest = dict(latest)

            # Fix the branch number before parsing: nested ifs bump the
            # counter, and both branches of this ``if`` must share one number.
            ctx_counts["if"] += 1
            branch_no = ctx_counts["if"]

            then_start = len(ops)
            latest_then = _run_block(stmt.body, f"then{branch_no}")
            else_start = len(ops)
            latest_else = _run_block(stmt.orelse or [], f"else{branch_no}")

            # Add cond dep to first op in each branch, if any
            if else_start > then_start:
                _link_first(then_start, cond_id)
            if len(ops) > else_start:
                _link_first(else_start, cond_id)

            # Merge every variable whose binding changed in either branch.
            assigned = {
                k
                for branch in (latest_then, latest_else)
                for k in branch
                if pre_latest.get(k) != branch.get(k)
            }
            for var in sorted(assigned):
                left = latest_then.get(var, pre_latest.get(var))
                right = latest_else.get(var, pre_latest.get(var))
                if left is None or right is None:
                    # Variable does not exist pre-branch on one side; skip making it available post-merge
                    continue
                _emit(var, "PHI", [left, right], {"var": var})

        def _parse_loop(head_id: str, body: List[ast.stmt], kind: str,
                        bindings: Optional[Dict[str, str]] = None) -> None:
            """Parse a loop body and emit PHI ops for loop-carried variables.

            ``head_id`` is the ITER/COND op the body hangs off, ``kind`` is
            the ``ctx_counts`` key and id-suffix prefix (``"loop"`` or
            ``"while"``) and ``bindings`` maps loop-target names to the id
            they resolve to inside the body.
            """
            pre_latest = dict(latest)
            ctx_counts[kind] += 1
            body_start = len(ops)
            latest_body = _run_block(body, f"{kind}{ctx_counts[kind]}", bindings)

            # Add head dep to first op in body
            if len(ops) > body_start:
                _link_first(body_start, head_id)

            # Loop-carried vars: only those existing pre-loop and rebound in body
            carried = sorted(
                k for k in latest_body
                if k in pre_latest and pre_latest[k] != latest_body[k]
            )
            for var in carried:
                _emit(var, "PHI", [pre_latest[var], latest_body[var]], {"var": var})

        def _parse_stmt(stmt: ast.stmt) -> Optional[str]:
            """Dispatch one statement; return the new SSA id for assignments."""
            if isinstance(stmt, ast.Assign):
                return _parse_assign(stmt)
            if isinstance(stmt, ast.Expr):
                _parse_expr(stmt)
            elif isinstance(stmt, ast.Return):
                _parse_return(stmt)
            elif isinstance(stmt, ast.If):
                _parse_if(stmt)
            elif isinstance(stmt, (ast.For, ast.AsyncFor)):
                iter_id = _emit_iter(stmt.iter)
                # Loop targets resolve to the iterator node inside the body so
                # that ``for x in xs: y = f(x)`` has a defined dependency.
                targets = {
                    n.id: iter_id
                    for n in ast.walk(stmt.target)
                    if isinstance(n, ast.Name) and isinstance(n.ctx, ast.Store)
                }
                _parse_loop(iter_id, stmt.body, "loop", targets)
            elif isinstance(stmt, ast.While):
                cond_id = _emit_cond(stmt.test, kind="while")
                _parse_loop(cond_id, stmt.body, "while")
            elif isinstance(stmt, ast.Pass):
                pass
            else:
                raise DSLParseError(
                    "Only assignments, control flow, settings/output calls, and return are allowed in function body")
            return None

        # Parse body sequentially; still require a resulting output
        for stmt in fn.body:  # type: ignore[attr-defined]
            _parse_stmt(stmt)

        if not outputs:
            if returned_var is None:
                raise DSLParseError("At least one output() call required")
            outputs.append({"from": returned_var, "as": "return"})
        if len(ops) > 2000:
            raise DSLParseError("Too many operations")

        fn_name = getattr(fn, "name", None)
        plan: Dict[str, Any] = {"version": 2, "function": fn_name, "ops": ops, "outputs": outputs}
        if settings:
            plan["settings"] = settings
        return plan

    functions = [
        node for node in module.body
        if isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef))
    ]

    # If a specific function name is provided, use it; otherwise try to auto-detect
    if function_name is not None:
        for node in functions:
            if node.name == function_name:
                return _parse_fn(node)
        raise DSLParseError(f"Function {function_name!r} not found")

    last_err: Optional[Exception] = None
    for node in functions:
        try:
            return _parse_fn(node)
        except DSLParseError as e:
            last_err = e
    # If we got here, either there are no functions or none matched the DSL
    if last_err is not None:
        raise DSLParseError(
            "No suitable function matched the DSL; specify --func to disambiguate") from last_err
    raise DSLParseError("No function definitions found in source")
|
495
|
+
|
496
|
+
|
497
|
+
def parse_file(filename: str, function_name: Optional[str] = None) -> Dict[str, Any]:
    """Read ``filename`` as UTF-8 text and return the plan ``parse`` builds from it.

    ``function_name`` is forwarded to ``parse`` unchanged.
    """
    with open(filename, mode="r", encoding="utf-8") as handle:
        text = handle.read()
    plan = parse(text, function_name=function_name)
    return plan
|
py2dag-0.1.15/py2dag/parser.py
DELETED
@@ -1,335 +0,0 @@
|
|
1
|
-
import ast
|
2
|
-
import json
|
3
|
-
import re
|
4
|
-
from typing import Any, Dict, List, Optional
|
5
|
-
|
6
|
-
VALID_NAME_RE = re.compile(r'^[a-z_][a-z0-9_]{0,63}$')
|
7
|
-
|
8
|
-
|
9
|
-
class DSLParseError(Exception):
|
10
|
-
"""Raised when the mini-DSL constraints are violated."""
|
11
|
-
|
12
|
-
|
13
|
-
def _literal(node: ast.AST) -> Any:
|
14
|
-
"""Return a Python literal from an AST node or raise DSLParseError."""
|
15
|
-
if isinstance(node, ast.Constant):
|
16
|
-
return node.value
|
17
|
-
if isinstance(node, (ast.List, ast.Tuple)):
|
18
|
-
return [_literal(elt) for elt in node.elts]
|
19
|
-
if isinstance(node, ast.Dict):
|
20
|
-
return {_literal(k): _literal(v) for k, v in zip(node.keys, node.values)}
|
21
|
-
raise DSLParseError("Keyword argument values must be JSON-serialisable literals")
|
22
|
-
|
23
|
-
|
24
|
-
def _get_call_name(func: ast.AST) -> str:
|
25
|
-
if isinstance(func, ast.Name):
|
26
|
-
return func.id
|
27
|
-
if isinstance(func, ast.Attribute):
|
28
|
-
parts: List[str] = []
|
29
|
-
while isinstance(func, ast.Attribute):
|
30
|
-
parts.append(func.attr)
|
31
|
-
func = func.value
|
32
|
-
if isinstance(func, ast.Name):
|
33
|
-
parts.append(func.id)
|
34
|
-
return ".".join(reversed(parts))
|
35
|
-
raise DSLParseError("Only simple or attribute names are allowed for operations")
|
36
|
-
|
37
|
-
|
38
|
-
def parse(source: str, function_name: Optional[str] = None) -> Dict[str, Any]:
|
39
|
-
if len(source) > 20_000:
|
40
|
-
raise DSLParseError("Source too large")
|
41
|
-
module = ast.parse(source)
|
42
|
-
|
43
|
-
def _parse_fn(fn: ast.AST) -> Dict[str, Any]:
|
44
|
-
defined: set[str] = set()
|
45
|
-
ops: List[Dict[str, Any]] = []
|
46
|
-
outputs: List[Dict[str, str]] = []
|
47
|
-
settings: Dict[str, Any] = {}
|
48
|
-
|
49
|
-
returned_var: Optional[str] = None
|
50
|
-
# Enforce no-args top-level function signature
|
51
|
-
try:
|
52
|
-
fargs = getattr(fn, "args") # type: ignore[attr-defined]
|
53
|
-
has_params = bool(
|
54
|
-
getattr(fargs, "posonlyargs", []) or fargs.args or fargs.vararg or fargs.kwonlyargs or fargs.kwarg
|
55
|
-
)
|
56
|
-
if has_params:
|
57
|
-
raise DSLParseError("Top-level function must not accept parameters")
|
58
|
-
except AttributeError:
|
59
|
-
# If args not present, ignore
|
60
|
-
pass
|
61
|
-
|
62
|
-
def _collect_name_deps(node: ast.AST) -> List[str]:
|
63
|
-
names: List[str] = []
|
64
|
-
for n in ast.walk(node):
|
65
|
-
if isinstance(n, ast.Name) and isinstance(n.ctx, ast.Load):
|
66
|
-
if n.id not in names:
|
67
|
-
names.append(n.id)
|
68
|
-
return names
|
69
|
-
# type: ignore[attr-defined]
|
70
|
-
for i, stmt in enumerate(fn.body): # type: ignore[attr-defined]
|
71
|
-
if isinstance(stmt, ast.Assign):
|
72
|
-
if len(stmt.targets) != 1 or not isinstance(stmt.targets[0], ast.Name):
|
73
|
-
raise DSLParseError("Assignment targets must be simple names")
|
74
|
-
var_name = stmt.targets[0].id
|
75
|
-
if not VALID_NAME_RE.match(var_name):
|
76
|
-
raise DSLParseError(f"Invalid variable name: {var_name}")
|
77
|
-
if var_name in defined:
|
78
|
-
raise DSLParseError(f"Duplicate variable name: {var_name}")
|
79
|
-
|
80
|
-
value = stmt.value
|
81
|
-
if isinstance(value, ast.Await):
|
82
|
-
value = value.value
|
83
|
-
if isinstance(value, ast.Call):
|
84
|
-
op_name = _get_call_name(value.func)
|
85
|
-
|
86
|
-
deps: List[str] = []
|
87
|
-
|
88
|
-
def _expand_star_name(varname: str) -> List[str]:
|
89
|
-
# Try to expand a previously packed list/tuple variable into its element deps
|
90
|
-
for prev in reversed(ops):
|
91
|
-
if prev.get("id") == varname:
|
92
|
-
if prev.get("op") in {"PACK.list", "PACK.tuple"}:
|
93
|
-
return list(prev.get("deps", []))
|
94
|
-
break
|
95
|
-
return [varname]
|
96
|
-
for arg in value.args:
|
97
|
-
if isinstance(arg, ast.Starred):
|
98
|
-
star_val = arg.value
|
99
|
-
if isinstance(star_val, ast.Name):
|
100
|
-
if star_val.id not in defined:
|
101
|
-
raise DSLParseError(f"Undefined dependency: {star_val.id}")
|
102
|
-
deps.extend(_expand_star_name(star_val.id))
|
103
|
-
elif isinstance(star_val, (ast.List, ast.Tuple)):
|
104
|
-
for elt in star_val.elts:
|
105
|
-
if not isinstance(elt, ast.Name):
|
106
|
-
raise DSLParseError("Starred list/tuple elements must be names")
|
107
|
-
if elt.id not in defined:
|
108
|
-
raise DSLParseError(f"Undefined dependency: {elt.id}")
|
109
|
-
deps.append(elt.id)
|
110
|
-
else:
|
111
|
-
raise DSLParseError("*args must be a name or list/tuple of names")
|
112
|
-
elif isinstance(arg, ast.Name):
|
113
|
-
if arg.id not in defined:
|
114
|
-
raise DSLParseError(f"Undefined dependency: {arg.id}")
|
115
|
-
deps.append(arg.id)
|
116
|
-
elif isinstance(arg, (ast.List, ast.Tuple)):
|
117
|
-
for elt in arg.elts:
|
118
|
-
if not isinstance(elt, ast.Name):
|
119
|
-
raise DSLParseError("List/Tuple positional args must be variable names")
|
120
|
-
if elt.id not in defined:
|
121
|
-
raise DSLParseError(f"Undefined dependency: {elt.id}")
|
122
|
-
deps.append(elt.id)
|
123
|
-
else:
|
124
|
-
raise DSLParseError("Positional args must be variable names or lists/tuples of names")
|
125
|
-
|
126
|
-
kwargs: Dict[str, Any] = {}
|
127
|
-
for kw in value.keywords:
|
128
|
-
if kw.arg is None:
|
129
|
-
# **kwargs support: allow dict literal merge, or variable name as dep
|
130
|
-
v = kw.value
|
131
|
-
if isinstance(v, ast.Dict):
|
132
|
-
# Merge literal kwargs
|
133
|
-
lit = _literal(v)
|
134
|
-
for k, val in lit.items():
|
135
|
-
kwargs[str(k)] = val
|
136
|
-
elif isinstance(v, ast.Name):
|
137
|
-
if v.id not in defined:
|
138
|
-
raise DSLParseError(f"Undefined dependency: {v.id}")
|
139
|
-
deps.append(v.id)
|
140
|
-
else:
|
141
|
-
raise DSLParseError("**kwargs must be a dict literal or a variable name")
|
142
|
-
else:
|
143
|
-
# Support variable-name keyword args as dependencies; literals remain in args
|
144
|
-
if isinstance(kw.value, ast.Name):
|
145
|
-
name = kw.value.id
|
146
|
-
if name not in defined:
|
147
|
-
raise DSLParseError(f"Undefined dependency: {name}")
|
148
|
-
deps.append(name)
|
149
|
-
else:
|
150
|
-
kwargs[kw.arg] = _literal(kw.value)
|
151
|
-
|
152
|
-
ops.append({"id": var_name, "op": op_name, "deps": deps, "args": kwargs})
|
153
|
-
elif isinstance(value, ast.JoinedStr):
|
154
|
-
# Minimal f-string support: only variable placeholders
|
155
|
-
deps: List[str] = []
|
156
|
-
parts: List[str] = []
|
157
|
-
for item in value.values:
|
158
|
-
if isinstance(item, ast.Constant) and isinstance(item.value, str):
|
159
|
-
parts.append(item.value)
|
160
|
-
elif isinstance(item, ast.FormattedValue) and isinstance(item.value, ast.Name):
|
161
|
-
name = item.value.id
|
162
|
-
if name not in defined:
|
163
|
-
raise DSLParseError(f"Undefined dependency: {name}")
|
164
|
-
deps.append(name)
|
165
|
-
parts.append("{" + str(len(deps) - 1) + "}")
|
166
|
-
else:
|
167
|
-
raise DSLParseError("f-strings may only contain variable names")
|
168
|
-
template = "".join(parts)
|
169
|
-
ops.append({
|
170
|
-
"id": var_name,
|
171
|
-
"op": "TEXT.format",
|
172
|
-
"deps": deps,
|
173
|
-
"args": {"template": template},
|
174
|
-
})
|
175
|
-
elif isinstance(value, (ast.Constant, ast.List, ast.Tuple, ast.Dict)):
|
176
|
-
# Allow assigning literals; also support packing lists/tuples of names
|
177
|
-
try:
|
178
|
-
lit = _literal(value)
|
179
|
-
ops.append({
|
180
|
-
"id": var_name,
|
181
|
-
"op": "CONST.value",
|
182
|
-
"deps": [],
|
183
|
-
"args": {"value": lit},
|
184
|
-
})
|
185
|
-
except DSLParseError:
|
186
|
-
if isinstance(value, (ast.List, ast.Tuple)):
|
187
|
-
elts = value.elts
|
188
|
-
names: List[str] = []
|
189
|
-
for elt in elts:
|
190
|
-
if not isinstance(elt, ast.Name):
|
191
|
-
raise DSLParseError("Only names allowed in non-literal list/tuple assignment")
|
192
|
-
if elt.id not in defined:
|
193
|
-
raise DSLParseError(f"Undefined dependency: {elt.id}")
|
194
|
-
names.append(elt.id)
|
195
|
-
kind = "list" if isinstance(value, ast.List) else "tuple"
|
196
|
-
ops.append({
|
197
|
-
"id": var_name,
|
198
|
-
"op": f"PACK.{kind}",
|
199
|
-
"deps": names,
|
200
|
-
"args": {},
|
201
|
-
})
|
202
|
-
else:
|
203
|
-
raise
|
204
|
-
elif isinstance(value, (ast.ListComp, ast.SetComp, ast.DictComp, ast.GeneratorExp)):
|
205
|
-
# Basic comprehension support: collect name deps and emit a generic comp op
|
206
|
-
name_deps = [n for n in _collect_name_deps(value) if n in defined]
|
207
|
-
# Ensure no undefined names used
|
208
|
-
for n in name_deps:
|
209
|
-
if n not in defined:
|
210
|
-
raise DSLParseError(f"Undefined dependency: {n}")
|
211
|
-
kind = (
|
212
|
-
"listcomp" if isinstance(value, ast.ListComp) else
|
213
|
-
"setcomp" if isinstance(value, ast.SetComp) else
|
214
|
-
"dictcomp" if isinstance(value, ast.DictComp) else
|
215
|
-
"genexpr"
|
216
|
-
)
|
217
|
-
ops.append({
|
218
|
-
"id": var_name,
|
219
|
-
"op": f"COMP.{kind}",
|
220
|
-
"deps": name_deps,
|
221
|
-
"args": {},
|
222
|
-
})
|
223
|
-
else:
|
224
|
-
raise DSLParseError("Right hand side must be a call or f-string")
|
225
|
-
defined.add(var_name)
|
226
|
-
|
227
|
-
elif isinstance(stmt, ast.Expr):
|
228
|
-
call = stmt.value
|
229
|
-
if isinstance(call, ast.Await):
|
230
|
-
call = call.value
|
231
|
-
if not isinstance(call, ast.Call):
|
232
|
-
raise DSLParseError("Only call expressions allowed at top level")
|
233
|
-
name = _get_call_name(call.func)
|
234
|
-
if name == "settings":
|
235
|
-
for kw in call.keywords:
|
236
|
-
if kw.arg is None:
|
237
|
-
raise DSLParseError("settings does not accept **kwargs")
|
238
|
-
settings[kw.arg] = _literal(kw.value)
|
239
|
-
if call.args:
|
240
|
-
raise DSLParseError("settings only accepts keyword literals")
|
241
|
-
elif name == "output":
|
242
|
-
if len(call.args) != 1 or not isinstance(call.args[0], ast.Name):
|
243
|
-
raise DSLParseError("output requires a single variable name argument")
|
244
|
-
var = call.args[0].id
|
245
|
-
if var not in defined:
|
246
|
-
raise DSLParseError(f"Undefined output variable: {var}")
|
247
|
-
filename = None
|
248
|
-
for kw in call.keywords:
|
249
|
-
if kw.arg in {"as", "as_"}:
|
250
|
-
filename = _literal(kw.value)
|
251
|
-
else:
|
252
|
-
raise DSLParseError("output only accepts 'as' keyword")
|
253
|
-
if filename is None or not isinstance(filename, str):
|
254
|
-
raise DSLParseError("output requires as=\"filename\"")
|
255
|
-
outputs.append({"from": var, "as": filename})
|
256
|
-
else:
|
257
|
-
raise DSLParseError("Only settings() and output() calls allowed as expressions")
|
258
|
-
elif isinstance(stmt, ast.Return):
|
259
|
-
if i != len(fn.body) - 1: # type: ignore[index]
|
260
|
-
raise DSLParseError("return must be the last statement")
|
261
|
-
if isinstance(stmt.value, ast.Name):
|
262
|
-
var = stmt.value.id
|
263
|
-
if var not in defined:
|
264
|
-
raise DSLParseError(f"Undefined return variable: {var}")
|
265
|
-
returned_var = var
|
266
|
-
elif isinstance(stmt.value, (ast.Constant, ast.List, ast.Tuple, ast.Dict)):
|
267
|
-
# Support returning a JSON-serialisable literal (str/num/bool/None, list/tuple, dict)
|
268
|
-
lit = _literal(stmt.value)
|
269
|
-
const_id_base = "return_value"
|
270
|
-
const_id = const_id_base
|
271
|
-
n = 1
|
272
|
-
while const_id in defined:
|
273
|
-
const_id = f"{const_id_base}_{n}"
|
274
|
-
n += 1
|
275
|
-
ops.append({
|
276
|
-
"id": const_id,
|
277
|
-
"op": "CONST.value",
|
278
|
-
"deps": [],
|
279
|
-
"args": {"value": lit},
|
280
|
-
})
|
281
|
-
returned_var = const_id
|
282
|
-
else:
|
283
|
-
raise DSLParseError("return must return a variable name or literal")
|
284
|
-
elif isinstance(stmt, (ast.For, ast.AsyncFor, ast.While, ast.If, ast.Match)):
|
285
|
-
# Ignore control flow blocks; only top-level linear statements are modeled
|
286
|
-
continue
|
287
|
-
elif isinstance(stmt, (ast.Pass,)):
|
288
|
-
continue
|
289
|
-
else:
|
290
|
-
raise DSLParseError("Only assignments, expression calls, and a final return are allowed in function body")
|
291
|
-
|
292
|
-
if not outputs:
|
293
|
-
if returned_var is not None:
|
294
|
-
outputs.append({"from": returned_var, "as": "return"})
|
295
|
-
else:
|
296
|
-
raise DSLParseError("At least one output() call required")
|
297
|
-
if len(ops) > 200:
|
298
|
-
raise DSLParseError("Too many operations")
|
299
|
-
|
300
|
-
# Include the parsed function name for visibility/debugging
|
301
|
-
fn_name = getattr(fn, "name", None) # type: ignore[attr-defined]
|
302
|
-
plan: Dict[str, Any] = {"version": 1, "function": fn_name, "ops": ops, "outputs": outputs}
|
303
|
-
if settings:
|
304
|
-
plan["settings"] = settings
|
305
|
-
return plan
|
306
|
-
|
307
|
-
# If a specific function name is provided, use it; otherwise try to auto-detect
|
308
|
-
if function_name is not None:
|
309
|
-
fn: Optional[ast.AST] = None
|
310
|
-
for node in module.body:
|
311
|
-
if isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef)) and node.name == function_name:
|
312
|
-
fn = node
|
313
|
-
break
|
314
|
-
if fn is None:
|
315
|
-
raise DSLParseError(f"Function {function_name!r} not found")
|
316
|
-
return _parse_fn(fn)
|
317
|
-
else:
|
318
|
-
last_err: Optional[Exception] = None
|
319
|
-
for node in module.body:
|
320
|
-
if isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef)):
|
321
|
-
try:
|
322
|
-
return _parse_fn(node)
|
323
|
-
except DSLParseError as e:
|
324
|
-
last_err = e
|
325
|
-
continue
|
326
|
-
# If we got here, either there are no functions or none matched the DSL
|
327
|
-
if last_err is not None:
|
328
|
-
raise DSLParseError("No suitable function matched the DSL; specify --func to disambiguate") from last_err
|
329
|
-
raise DSLParseError("No function definitions found in source")
|
330
|
-
|
331
|
-
|
332
|
-
def parse_file(filename: str, function_name: Optional[str] = None) -> Dict[str, Any]:
    """Parse the Python source stored in the file at ``filename``.

    The file is opened in text mode and decoded as UTF-8; its full contents
    are then handed to ``parse`` and that call's result is returned as-is.

    Args:
        filename: Path of the source file to read.
        function_name: Forwarded unchanged to ``parse`` as its
            ``function_name`` keyword argument (``None`` by default).

    Returns:
        Whatever ``parse`` returns for the file's contents.
    """
    with open(filename, "r", encoding="utf-8") as source_file:
        source_text = source_file.read()
    return parse(source_text, function_name=function_name)
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|