py2dag 0.1.14__tar.gz → 0.2.0__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.3
2
2
  Name: py2dag
3
- Version: 0.1.14
3
+ Version: 0.2.0
4
4
  Summary: Convert Python function plans to DAG (JSON, pseudo, optional SVG).
5
5
  License: MIT
6
6
  Author: rvergis
@@ -48,9 +48,22 @@ HTML_TEMPLATE = """<!doctype html>
48
48
  // Ensure edges have an object for labels/attrs to avoid TypeErrors
49
49
  g.setDefaultEdgeLabel(() => ({}));
50
50
 
51
- // Add op nodes
51
+ // Add op nodes with basic styling for control nodes
52
52
  (plan.ops || []).forEach(op => {
53
- g.setNode(op.id, { label: op.op, class: 'op', padding: 8 });
53
+ let label = op.op;
54
+ let klass = 'op';
55
+ if (op.op === 'COND.eval') {
56
+ const kind = (op.args && op.args.kind) || 'if';
57
+ label = (kind.toUpperCase()) + ' ' + (op.args && op.args.expr ? op.args.expr : '');
58
+ klass = 'note';
59
+ } else if (op.op === 'ITER.eval') {
60
+ label = 'FOR ' + (op.args && op.args.expr ? op.args.expr : '');
61
+ klass = 'note';
62
+ } else if (op.op === 'PHI') {
63
+ label = 'PHI' + (op.args && op.args.var ? ` (${op.args.var})` : '');
64
+ klass = 'note';
65
+ }
66
+ g.setNode(op.id, { label, class: klass, padding: 8 });
54
67
  });
55
68
 
56
69
  // Add output nodes and edges from source to output
@@ -0,0 +1,500 @@
1
+ import ast
2
+ import json
3
+ import re
4
+ from typing import Any, Dict, List, Optional, Tuple, Set
5
+
6
# Allowed plan variable names: lowercase ASCII letter or underscore first,
# then up to 63 lowercase letters, digits or underscores (64 chars max).
VALID_NAME_RE = re.compile(r"^[a-z_][a-z0-9_]{0,63}$")
7
+
8
+
9
class DSLParseError(Exception):
    """Signal that the parsed source falls outside the supported plan mini-DSL.

    The message describes which constraint was violated (unsupported
    statement, undefined dependency, non-literal argument, ...).
    """
11
+
12
+
13
+ def _literal(node: ast.AST) -> Any:
14
+ """Return a Python literal from an AST node or raise DSLParseError."""
15
+ if isinstance(node, ast.Constant):
16
+ return node.value
17
+ if isinstance(node, (ast.List, ast.Tuple)):
18
+ return [_literal(elt) for elt in node.elts]
19
+ if isinstance(node, ast.Dict):
20
+ return {_literal(k): _literal(v) for k, v in zip(node.keys, node.values)}
21
+ raise DSLParseError("Keyword argument values must be JSON-serialisable literals")
22
+
23
+
24
+ def _get_call_name(func: ast.AST) -> str:
25
+ if isinstance(func, ast.Name):
26
+ return func.id
27
+ if isinstance(func, ast.Attribute):
28
+ parts: List[str] = []
29
+ while isinstance(func, ast.Attribute):
30
+ parts.append(func.attr)
31
+ func = func.value
32
+ if isinstance(func, ast.Name):
33
+ parts.append(func.id)
34
+ return ".".join(reversed(parts))
35
+ raise DSLParseError("Only simple or attribute names are allowed for operations")
36
+
37
+
38
def parse(source: str, function_name: Optional[str] = None) -> Dict[str, Any]:
    """Translate a plan function written in the py2dag mini-DSL into a plan dict.

    The plan has the shape ``{"version": 2, "function": <name>, "ops": [...],
    "outputs": [...]}``, plus ``"settings"`` when ``settings()`` was called.
    Each op is a dict with ``id``, ``op``, ``deps`` and ``args``.

    Variables are renamed into SSA form. Every assignment to ``x`` yields a
    fresh id ``x_<n>``. Inside a control-flow body the id is suffixed with
    ``@then<k>``, ``@else<k>``, ``@loop<k>`` or ``@while<k>``.

    Control flow is modelled as follows:
    * ``if``/``for``/``while`` emit ``COND.eval``/``ITER.eval`` ops.
    * The first op of each body depends on that control op.
    * ``PHI`` ops merge variables that a body reassigned.

    Args:
        source: Python source containing one or more function definitions.
        function_name: Function to translate. When ``None``, the first
            top-level function that satisfies the DSL is used.

    Returns:
        The plan dictionary.

    Raises:
        DSLParseError: if the source exceeds 20,000 characters, the requested
            function is missing, or no function satisfies the DSL.
        SyntaxError: propagated unchanged from ``ast.parse``.
    """
    if len(source) > 20_000:
        raise DSLParseError("Source too large")
    module = ast.parse(source)

    def _parse_fn(fn: ast.AST) -> Dict[str, Any]:
        """Translate one function definition node into a plan dict."""
        ops: List[Dict[str, Any]] = []
        outputs: List[Dict[str, str]] = []
        settings: Dict[str, Any] = {}
        returned_var: Optional[str] = None

        # Plans are self-contained: the entry function must take no parameters.
        try:
            fargs = getattr(fn, "args")  # type: ignore[attr-defined]
            has_params = bool(
                getattr(fargs, "posonlyargs", []) or fargs.args or fargs.vararg or fargs.kwonlyargs or fargs.kwarg
            )
            if has_params:
                raise DSLParseError("Top-level function must not accept parameters")
        except AttributeError:
            pass

        # SSA state:
        # - ``versions`` counts ids handed out per base name.
        # - ``latest`` maps each visible user variable to its current SSA id.
        # - ``context_suffix`` tags ids emitted inside a branch/loop body.
        # Bodies run against copies of ``versions``/``latest`` (see ``_run_scoped``).
        versions: Dict[str, int] = {}
        latest: Dict[str, str] = {}
        context_suffix: str = ""
        ctx_counts: Dict[str, int] = {"if": 0, "loop": 0, "while": 0}

        def _next_id(name: str) -> str:
            """Allocate a fresh SSA id for *name* without making it a visible variable."""
            versions[name] = versions.get(name, 0) + 1
            base = f"{name}_{versions[name]}"
            return f"{base}@{context_suffix}" if context_suffix else base

        def _ssa_new(name: str) -> str:
            """Allocate a fresh SSA id for user variable *name* and make it current."""
            if not VALID_NAME_RE.match(name):
                raise DSLParseError(f"Invalid variable name: {name}")
            ssa = _next_id(name)
            latest[name] = ssa
            return ssa

        def _ssa_get(name: str) -> str:
            """Return the current SSA id of *name*; raise if it is not defined."""
            if name not in latest:
                raise DSLParseError(f"Undefined dependency: {name}")
            return latest[name]

        def _collect_value_deps(node: ast.AST) -> List[str]:
            """Collect the variable names an expression reads, in first-seen order.

            Names inside a call's callee are skipped, so ``range(n)`` gives
            ``['n']`` and ``cond(a)`` gives ``['a']``. ``obj.attr`` gives
            ``['obj']``. Exclusion is per AST node: a name used both as a callee
            and as a plain value is still reported for the value use.
            """
            callee_nodes: Set[int] = set()
            for sub in ast.walk(node):
                if isinstance(sub, ast.Call):
                    callee_nodes.update(id(n) for n in ast.walk(sub.func) if isinstance(n, ast.Name))
            deps: List[str] = []
            for sub in ast.walk(node):
                if (
                    isinstance(sub, ast.Name)
                    and isinstance(sub.ctx, ast.Load)
                    and id(sub) not in callee_nodes
                    and sub.id not in deps
                ):
                    deps.append(sub.id)
            return deps

        def _stringify(node: ast.AST) -> str:
            """Render *node* back to source for display, or its class name if unparse fails."""
            try:
                return ast.unparse(node)  # type: ignore[attr-defined]
            except Exception:
                return node.__class__.__name__

        def _emit_assign_from_call(var_name: str, call: ast.Call) -> str:
            """Emit an op for ``var = op(...)``.

            Name arguments become deps; literal keyword values become args.
            """
            op_name = _get_call_name(call.func)
            deps: List[str] = []

            def _expand_star_name(ssa_var: str) -> List[str]:
                # ``*xs`` where ``xs`` came from PACK.list/PACK.tuple expands to the packed elements.
                for prev in reversed(ops):
                    if prev.get("id") == ssa_var:
                        if prev.get("op") in {"PACK.list", "PACK.tuple"}:
                            return list(prev.get("deps", []))
                        break
                return [ssa_var]

            for arg in call.args:
                if isinstance(arg, ast.Starred):
                    star_val = arg.value
                    if isinstance(star_val, ast.Name):
                        deps.extend(_expand_star_name(_ssa_get(star_val.id)))
                    elif isinstance(star_val, (ast.List, ast.Tuple)):
                        for elt in star_val.elts:
                            if not isinstance(elt, ast.Name):
                                raise DSLParseError("Starred list/tuple elements must be names")
                            deps.append(_ssa_get(elt.id))
                    else:
                        raise DSLParseError("*args must be a name or list/tuple of names")
                elif isinstance(arg, ast.Name):
                    deps.append(_ssa_get(arg.id))
                elif isinstance(arg, (ast.List, ast.Tuple)):
                    for elt in arg.elts:
                        if not isinstance(elt, ast.Name):
                            raise DSLParseError("List/Tuple positional args must be variable names")
                        deps.append(_ssa_get(elt.id))
                else:
                    raise DSLParseError("Positional args must be variable names or lists/tuples of names")

            kwargs: Dict[str, Any] = {}
            for kw in call.keywords:
                if kw.arg is None:
                    v = kw.value
                    if isinstance(v, ast.Dict):
                        for k, val in _literal(v).items():
                            kwargs[str(k)] = val
                    elif isinstance(v, ast.Name):
                        deps.append(_ssa_get(v.id))
                    else:
                        raise DSLParseError("**kwargs must be a dict literal or a variable name")
                elif isinstance(kw.value, ast.Name):
                    deps.append(_ssa_get(kw.value.id))
                else:
                    kwargs[kw.arg] = _literal(kw.value)

            ssa = _ssa_new(var_name)
            ops.append({"id": ssa, "op": op_name, "deps": deps, "args": kwargs})
            return ssa

        def _emit_assign_from_fstring(var_name: str, fstr: ast.JoinedStr) -> str:
            """Emit a TEXT.format op.

            Each ``{name}`` placeholder becomes a positional ``{i}`` slot bound
            to deps[i].
            """
            deps: List[str] = []
            parts: List[str] = []
            for item in fstr.values:
                if isinstance(item, ast.Constant) and isinstance(item.value, str):
                    parts.append(item.value)
                elif isinstance(item, ast.FormattedValue) and isinstance(item.value, ast.Name):
                    deps.append(_ssa_get(item.value.id))
                    parts.append("{" + str(len(deps) - 1) + "}")
                else:
                    raise DSLParseError("f-strings may only contain variable names")
            ssa = _ssa_new(var_name)
            ops.append({
                "id": ssa,
                "op": "TEXT.format",
                "deps": deps,
                "args": {"template": "".join(parts)},
            })
            return ssa

        def _emit_assign_from_literal_or_pack(var_name: str, value: ast.AST) -> str:
            """Emit CONST.value for a pure literal.

            Otherwise emit PACK.list/PACK.tuple for a list/tuple of variable names.
            """
            try:
                lit = _literal(value)
            except DSLParseError:
                if not isinstance(value, (ast.List, ast.Tuple)):
                    raise
                deps: List[str] = []
                for elt in value.elts:
                    if not isinstance(elt, ast.Name):
                        raise DSLParseError("Only names allowed in non-literal list/tuple assignment")
                    deps.append(_ssa_get(elt.id))
                kind = "list" if isinstance(value, ast.List) else "tuple"
                ssa = _ssa_new(var_name)
                ops.append({"id": ssa, "op": f"PACK.{kind}", "deps": deps, "args": {}})
                return ssa
            ssa = _ssa_new(var_name)
            ops.append({"id": ssa, "op": "CONST.value", "deps": [], "args": {"value": lit}})
            return ssa

        def _emit_assign_from_comp(var_name: str, node: ast.AST) -> str:
            """Emit a COMP.<kind> op depending on the plan variables the comprehension reads.

            Names bound by the comprehension's own ``for`` clauses are local to
            it and are not dependencies. The exception is the first iterable,
            which Python evaluates in the enclosing scope. Names that are not
            plan variables (builtins, helpers) are ignored.
            """
            bound = {n.id for n in ast.walk(node) if isinstance(n, ast.Name) and isinstance(n.ctx, ast.Store)}
            outer_iter = {id(n) for n in ast.walk(node.generators[0].iter)}  # type: ignore[attr-defined]
            deps: List[str] = []
            for sub in ast.walk(node):
                if not (isinstance(sub, ast.Name) and isinstance(sub.ctx, ast.Load)):
                    continue
                if sub.id in bound and id(sub) not in outer_iter:
                    continue
                if sub.id in latest and latest[sub.id] not in deps:
                    deps.append(latest[sub.id])
            kind = (
                "listcomp" if isinstance(node, ast.ListComp) else
                "setcomp" if isinstance(node, ast.SetComp) else
                "dictcomp" if isinstance(node, ast.DictComp) else
                "genexpr"
            )
            ssa = _ssa_new(var_name)
            ops.append({"id": ssa, "op": f"COMP.{kind}", "deps": deps, "args": {}})
            return ssa

        def _emit_cond(node: ast.AST, kind: str = "if") -> str:
            """Emit a COND.eval op for an if/while test.

            It uses an internal id, so user variables are never shadowed.
            """
            expr = _stringify(node)
            deps = [_ssa_get(n) for n in _collect_value_deps(node)]
            ssa = _next_id("cond")
            ops.append({"id": ssa, "op": "COND.eval", "deps": deps, "args": {"expr": expr, "kind": kind}})
            return ssa

        def _emit_iter(node: ast.AST) -> str:
            """Emit an ITER.eval op for a for-loop iterable; it uses an internal id."""
            expr = _stringify(node)
            deps = [_ssa_get(n) for n in _collect_value_deps(node)]
            ssa = _next_id("iter")
            ops.append({"id": ssa, "op": "ITER.eval", "deps": deps, "args": {"expr": expr, "kind": "for"}})
            return ssa

        def _add_control_dep(start: int, end: int, ctrl_id: str) -> None:
            """Make the first op emitted in ops[start:end] depend on *ctrl_id* (once)."""
            if end > start:
                deps = ops[start].get("deps", [])
                if ctrl_id not in deps:
                    ops[start]["deps"] = [*deps, ctrl_id]

        def _run_scoped(
            body: List[ast.stmt], suffix: str, bindings: Optional[Dict[str, str]] = None
        ) -> Tuple[int, int, Dict[str, str]]:
            """Parse *body* against a copy of the SSA state under *suffix*.

            *bindings* pre-populates visible names (e.g. loop targets).

            Returns:
                ``(ops_start, ops_end, body_latest)``. The enclosing
                versions/latest/suffix are restored afterwards.
            """
            nonlocal versions, latest, context_suffix
            saved = (versions, latest, context_suffix)
            versions, latest, context_suffix = dict(versions), dict(latest), suffix
            if bindings:
                latest.update(bindings)
            start = len(ops)
            try:
                for inner in body:
                    _parse_stmt(inner)
                body_latest = latest
            finally:
                versions, latest, context_suffix = saved
            return start, len(ops), body_latest

        def _loop_target_names(target: ast.AST) -> List[str]:
            """Return the names bound by a ``for`` target (a name or nested tuples/lists of names)."""
            if isinstance(target, ast.Name):
                if not VALID_NAME_RE.match(target.id):
                    raise DSLParseError(f"Invalid variable name: {target.id}")
                return [target.id]
            if isinstance(target, (ast.Tuple, ast.List)):
                names: List[str] = []
                for elt in target.elts:
                    names.extend(_loop_target_names(elt))
                return names
            raise DSLParseError("Loop targets must be simple names or tuples of names")

        def _parse_loop(
            ctrl_id: str, suffix: str, body: List[ast.stmt], bindings: Optional[Dict[str, str]] = None
        ) -> None:
            """Parse a loop body, then emit PHI ops for variables it carries.

            A carried variable is one that existed before the loop and was
            rebound in the body.
            """
            pre_latest = dict(latest)
            start, end, body_latest = _run_scoped(body, suffix, bindings)
            _add_control_dep(start, end, ctrl_id)
            changed = {k for k, v in body_latest.items() if pre_latest.get(k) != v}
            for var in sorted(k for k in changed if k in pre_latest):
                phi_id = _ssa_new(var)
                ops.append({
                    "id": phi_id,
                    "op": "PHI",
                    "deps": [pre_latest[var], body_latest[var]],
                    "args": {"var": var},
                })

        def _parse_if(stmt: ast.If) -> None:
            """Parse both branches of an ``if``, then PHI-merge variables either branch rebound."""
            cond_id = _emit_cond(stmt.test, kind="if")
            pre_latest = dict(latest)
            # Reserve the branch number before parsing the body. Nested ifs bump
            # the counter, and reading it afterwards would give the else branch
            # a suffix already used by a nested else (duplicate op ids).
            ctx_counts["if"] += 1
            branch_no = ctx_counts["if"]
            then_start, then_end, then_latest = _run_scoped(stmt.body, f"then{branch_no}")
            else_start, else_end, else_latest = _run_scoped(stmt.orelse, f"else{branch_no}")
            # Use each branch's own bounds: an empty then-branch must not tag
            # the first else op twice.
            _add_control_dep(then_start, then_end, cond_id)
            _add_control_dep(else_start, else_end, cond_id)

            assigned = {k for k, v in then_latest.items() if pre_latest.get(k) != v}
            assigned |= {k for k, v in else_latest.items() if pre_latest.get(k) != v}
            for var in sorted(assigned):
                left = then_latest.get(var, pre_latest.get(var))
                right = else_latest.get(var, pre_latest.get(var))
                if left is None or right is None:
                    # Defined on only one path: not available after the merge.
                    continue
                phi_id = _ssa_new(var)
                ops.append({"id": phi_id, "op": "PHI", "deps": [left, right], "args": {"var": var}})

        def _parse_stmt(stmt: ast.stmt) -> Optional[str]:
            """Translate one statement.

            Returns the new SSA id for assignments and None for everything else.
            """
            nonlocal returned_var
            if isinstance(stmt, ast.Assign):
                if len(stmt.targets) != 1 or not isinstance(stmt.targets[0], ast.Name):
                    raise DSLParseError("Assignment targets must be simple names")
                var_name = stmt.targets[0].id
                value = stmt.value
                if isinstance(value, ast.Await):
                    value = value.value
                if isinstance(value, ast.Call):
                    return _emit_assign_from_call(var_name, value)
                if isinstance(value, ast.JoinedStr):
                    return _emit_assign_from_fstring(var_name, value)
                if isinstance(value, (ast.Constant, ast.List, ast.Tuple, ast.Dict)):
                    return _emit_assign_from_literal_or_pack(var_name, value)
                if isinstance(value, (ast.ListComp, ast.SetComp, ast.DictComp, ast.GeneratorExp)):
                    return _emit_assign_from_comp(var_name, value)
                raise DSLParseError("Right hand side must be a call or f-string")
            elif isinstance(stmt, ast.Expr):
                call = stmt.value
                # Docstrings and ``...`` placeholders are no-ops.
                if isinstance(call, ast.Constant) and (isinstance(call.value, str) or call.value is Ellipsis):
                    return None
                if isinstance(call, ast.Await):
                    call = call.value
                if not isinstance(call, ast.Call):
                    raise DSLParseError("Only call expressions allowed at top level")
                name = _get_call_name(call.func)
                if name == "settings":
                    for kw in call.keywords:
                        if kw.arg is None:
                            raise DSLParseError("settings does not accept **kwargs")
                        settings[kw.arg] = _literal(kw.value)
                    if call.args:
                        raise DSLParseError("settings only accepts keyword literals")
                elif name == "output":
                    if len(call.args) != 1 or not isinstance(call.args[0], ast.Name):
                        raise DSLParseError("output requires a single variable name argument")
                    ssa_from = _ssa_get(call.args[0].id)
                    filename = None
                    for kw in call.keywords:
                        if kw.arg in {"as", "as_"}:
                            filename = _literal(kw.value)
                        else:
                            raise DSLParseError("output only accepts 'as' keyword")
                    if filename is None or not isinstance(filename, str):
                        raise DSLParseError("output requires as=\"filename\"")
                    outputs.append({"from": ssa_from, "as": filename})
                else:
                    raise DSLParseError("Only settings() and output() calls allowed as expressions")
                return None
            elif isinstance(stmt, ast.Return):
                if isinstance(stmt.value, ast.Name):
                    returned_var = _ssa_get(stmt.value.id)
                elif isinstance(stmt.value, (ast.Constant, ast.List, ast.Tuple, ast.Dict)):
                    lit = _literal(stmt.value)
                    # Internal id: a literal return must not shadow a user
                    # variable named ``return_value``.
                    const_id = _next_id("return_value")
                    ops.append({"id": const_id, "op": "CONST.value", "deps": [], "args": {"value": lit}})
                    returned_var = const_id
                else:
                    raise DSLParseError("return must return a variable name or literal")
                return None
            elif isinstance(stmt, ast.If):
                _parse_if(stmt)
                return None
            elif isinstance(stmt, (ast.For, ast.AsyncFor)):
                iter_id = _emit_iter(stmt.iter)
                # Inside the body the loop variable(s) are produced by the ITER op.
                bindings = {name: iter_id for name in _loop_target_names(stmt.target)}
                ctx_counts["loop"] += 1
                _parse_loop(iter_id, f"loop{ctx_counts['loop']}", stmt.body, bindings)
                # ``break`` is not accepted by the DSL, so an ``else:`` clause
                # always runs once the loop finishes; model it as straight-line code.
                for inner in stmt.orelse:
                    _parse_stmt(inner)
                return None
            elif isinstance(stmt, ast.While):
                cond_id = _emit_cond(stmt.test, kind="while")
                ctx_counts["while"] += 1
                _parse_loop(cond_id, f"while{ctx_counts['while']}", stmt.body)
                for inner in stmt.orelse:
                    _parse_stmt(inner)
                return None
            elif isinstance(stmt, ast.Pass):
                return None
            raise DSLParseError("Only assignments, control flow, settings/output calls, and return are allowed in function body")

        for stmt in fn.body:  # type: ignore[attr-defined]
            _parse_stmt(stmt)

        # A plan needs at least one output; a returned value serves as one.
        if not outputs:
            if returned_var is not None:
                outputs.append({"from": returned_var, "as": "return"})
            else:
                raise DSLParseError("At least one output() call required")
        if len(ops) > 2000:
            raise DSLParseError("Too many operations")

        fn_name = getattr(fn, "name", None)  # type: ignore[attr-defined]
        plan: Dict[str, Any] = {"version": 2, "function": fn_name, "ops": ops, "outputs": outputs}
        if settings:
            plan["settings"] = settings
        return plan

    # Use the named function if given; otherwise take the first one that parses.
    if function_name is not None:
        fn: Optional[ast.AST] = None
        for node in module.body:
            if isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef)) and node.name == function_name:
                fn = node
                break
        if fn is None:
            raise DSLParseError(f"Function {function_name!r} not found")
        return _parse_fn(fn)

    last_err: Optional[Exception] = None
    for node in module.body:
        if isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef)):
            try:
                return _parse_fn(node)
            except DSLParseError as e:
                last_err = e
    if last_err is not None:
        raise DSLParseError("No suitable function matched the DSL; specify --func to disambiguate") from last_err
    raise DSLParseError("No function definitions found in source")
495
+
496
+
497
def parse_file(filename: str, function_name: Optional[str] = None) -> Dict[str, Any]:
    """Read *filename* as UTF-8 text and return ``parse`` applied to its contents.

    Args:
        filename: Path of the Python source file to read.
        function_name: Forwarded to ``parse``; ``None`` auto-detects the function.

    Returns:
        The plan dictionary produced by ``parse``.
    """
    with open(filename, encoding="utf-8") as handle:
        source_text = handle.read()
    return parse(source_text, function_name=function_name)
@@ -4,7 +4,7 @@ build-backend = "poetry.core.masonry.api"
4
4
 
5
5
  [tool.poetry]
6
6
  name = "py2dag"
7
- version = "0.1.14"
7
+ version = "0.2.0"
8
8
  description = "Convert Python function plans to DAG (JSON, pseudo, optional SVG)."
9
9
  authors = ["rvergis"]
10
10
  license = "MIT"
@@ -1,330 +0,0 @@
1
- import ast
2
- import json
3
- import re
4
- from typing import Any, Dict, List, Optional
5
-
6
# Allowed plan variable names: lowercase ASCII letter or underscore first,
# then up to 63 lowercase letters, digits or underscores (64 chars max).
VALID_NAME_RE = re.compile(r"^[a-z_][a-z0-9_]{0,63}$")
7
-
8
-
9
class DSLParseError(Exception):
    """Signal that the parsed source falls outside the supported plan mini-DSL.

    The message describes which constraint was violated.
    """
11
-
12
-
13
- def _literal(node: ast.AST) -> Any:
14
- """Return a Python literal from an AST node or raise DSLParseError."""
15
- if isinstance(node, ast.Constant):
16
- return node.value
17
- if isinstance(node, (ast.List, ast.Tuple)):
18
- return [_literal(elt) for elt in node.elts]
19
- if isinstance(node, ast.Dict):
20
- return {_literal(k): _literal(v) for k, v in zip(node.keys, node.values)}
21
- raise DSLParseError("Keyword argument values must be JSON-serialisable literals")
22
-
23
-
24
- def _get_call_name(func: ast.AST) -> str:
25
- if isinstance(func, ast.Name):
26
- return func.id
27
- if isinstance(func, ast.Attribute):
28
- parts: List[str] = []
29
- while isinstance(func, ast.Attribute):
30
- parts.append(func.attr)
31
- func = func.value
32
- if isinstance(func, ast.Name):
33
- parts.append(func.id)
34
- return ".".join(reversed(parts))
35
- raise DSLParseError("Only simple or attribute names are allowed for operations")
36
-
37
-
38
def parse(source: str, function_name: Optional[str] = None) -> Dict[str, Any]:
    """Translate a straight-line plan function into a version-1 plan dict.

    How statements are handled:
    * Assignments from calls, f-strings, literals, list/tuple packs and
      comprehensions become ops whose id is the assigned variable name.
    * ``settings()`` and ``output()`` calls are recorded.
    * A final ``return`` becomes the output when no ``output()`` was called.
    * ``for``/``while``/``if``/``match`` blocks are skipped entirely.

    Args:
        source: Python source containing one or more function definitions.
        function_name: Function to translate. When ``None``, the first
            top-level function that satisfies the DSL is used.

    Returns:
        ``{"version": 1, "function", "ops", "outputs"}``, plus ``"settings"``
        when ``settings()`` was called.

    Raises:
        DSLParseError: if the source exceeds 20,000 characters, the requested
            function is missing, or no function satisfies the DSL.
        SyntaxError: propagated unchanged from ``ast.parse``.
    """
    if len(source) > 20_000:
        raise DSLParseError("Source too large")
    module = ast.parse(source)

    # ``ast.Match`` exists only on Python 3.10+. Referencing it unconditionally
    # raised AttributeError on older interpreters as soon as any statement
    # reached the control-flow check.
    match_node = getattr(ast, "Match", None)
    skipped_blocks = (ast.For, ast.AsyncFor, ast.While, ast.If) + ((match_node,) if match_node is not None else ())

    def _parse_fn(fn: ast.AST) -> Dict[str, Any]:
        """Translate one function definition node into a plan dict."""
        defined: set[str] = set()
        ops: List[Dict[str, Any]] = []
        outputs: List[Dict[str, str]] = []
        settings: Dict[str, Any] = {}
        returned_var: Optional[str] = None

        # Parameters of every kind are treated as externally supplied, pre-defined names.
        fargs = getattr(fn, "args", None)
        if fargs is not None:
            params = [*getattr(fargs, "posonlyargs", []), *fargs.args, *fargs.kwonlyargs, fargs.vararg, fargs.kwarg]
            defined.update(p.arg for p in params if p is not None)

        def _require(name: str) -> str:
            """Return *name* if it is defined; otherwise raise."""
            if name not in defined:
                raise DSLParseError(f"Undefined dependency: {name}")
            return name

        def _expand_star_name(varname: str) -> List[str]:
            """Expand a previously packed list/tuple variable into its element deps."""
            for prev in reversed(ops):
                if prev.get("id") == varname:
                    if prev.get("op") in {"PACK.list", "PACK.tuple"}:
                        return list(prev.get("deps", []))
                    break
            return [varname]

        def _op_from_call(var_name: str, call: ast.Call) -> Dict[str, Any]:
            """Build the op for ``var = op(...)``: name arguments become deps, literal keywords args."""
            op_name = _get_call_name(call.func)
            deps: List[str] = []
            for arg in call.args:
                if isinstance(arg, ast.Starred):
                    star_val = arg.value
                    if isinstance(star_val, ast.Name):
                        deps.extend(_expand_star_name(_require(star_val.id)))
                    elif isinstance(star_val, (ast.List, ast.Tuple)):
                        for elt in star_val.elts:
                            if not isinstance(elt, ast.Name):
                                raise DSLParseError("Starred list/tuple elements must be names")
                            deps.append(_require(elt.id))
                    else:
                        raise DSLParseError("*args must be a name or list/tuple of names")
                elif isinstance(arg, ast.Name):
                    deps.append(_require(arg.id))
                elif isinstance(arg, (ast.List, ast.Tuple)):
                    for elt in arg.elts:
                        if not isinstance(elt, ast.Name):
                            raise DSLParseError("List/Tuple positional args must be variable names")
                        deps.append(_require(elt.id))
                else:
                    raise DSLParseError("Positional args must be variable names or lists/tuples of names")

            kwargs: Dict[str, Any] = {}
            for kw in call.keywords:
                if kw.arg is None:
                    # ``**{...}`` merges literal kwargs; ``**name`` becomes a dep.
                    v = kw.value
                    if isinstance(v, ast.Dict):
                        for k, val in _literal(v).items():
                            kwargs[str(k)] = val
                    elif isinstance(v, ast.Name):
                        deps.append(_require(v.id))
                    else:
                        raise DSLParseError("**kwargs must be a dict literal or a variable name")
                elif isinstance(kw.value, ast.Name):
                    deps.append(_require(kw.value.id))
                else:
                    kwargs[kw.arg] = _literal(kw.value)
            return {"id": var_name, "op": op_name, "deps": deps, "args": kwargs}

        def _op_from_fstring(var_name: str, fstr: ast.JoinedStr) -> Dict[str, Any]:
            """Build a TEXT.format op; each ``{name}`` becomes a positional ``{i}`` slot."""
            deps: List[str] = []
            parts: List[str] = []
            for item in fstr.values:
                if isinstance(item, ast.Constant) and isinstance(item.value, str):
                    parts.append(item.value)
                elif isinstance(item, ast.FormattedValue) and isinstance(item.value, ast.Name):
                    deps.append(_require(item.value.id))
                    parts.append("{" + str(len(deps) - 1) + "}")
                else:
                    raise DSLParseError("f-strings may only contain variable names")
            return {"id": var_name, "op": "TEXT.format", "deps": deps, "args": {"template": "".join(parts)}}

        def _op_from_literal_or_pack(var_name: str, value: ast.AST) -> Dict[str, Any]:
            """Build CONST.value for a literal, or PACK.list/PACK.tuple for a list/tuple of names."""
            try:
                lit = _literal(value)
            except DSLParseError:
                if not isinstance(value, (ast.List, ast.Tuple)):
                    raise
                names: List[str] = []
                for elt in value.elts:
                    if not isinstance(elt, ast.Name):
                        raise DSLParseError("Only names allowed in non-literal list/tuple assignment")
                    names.append(_require(elt.id))
                kind = "list" if isinstance(value, ast.List) else "tuple"
                return {"id": var_name, "op": f"PACK.{kind}", "deps": names, "args": {}}
            return {"id": var_name, "op": "CONST.value", "deps": [], "args": {"value": lit}}

        def _op_from_comp(var_name: str, node: ast.AST) -> Dict[str, Any]:
            """Build a COMP.<kind> op depending on the defined names the comprehension reads.

            Names bound by the comprehension's own ``for`` clauses are local to
            it and are not dependencies. The exception is the first iterable,
            which Python evaluates in the enclosing scope. Undefined names
            (builtins, helpers) are ignored.
            """
            bound = {n.id for n in ast.walk(node) if isinstance(n, ast.Name) and isinstance(n.ctx, ast.Store)}
            outer_iter = {id(n) for n in ast.walk(node.generators[0].iter)}  # type: ignore[attr-defined]
            name_deps: List[str] = []
            for sub in ast.walk(node):
                if not (isinstance(sub, ast.Name) and isinstance(sub.ctx, ast.Load)):
                    continue
                if sub.id in bound and id(sub) not in outer_iter:
                    continue
                if sub.id in defined and sub.id not in name_deps:
                    name_deps.append(sub.id)
            kind = (
                "listcomp" if isinstance(node, ast.ListComp) else
                "setcomp" if isinstance(node, ast.SetComp) else
                "dictcomp" if isinstance(node, ast.DictComp) else
                "genexpr"
            )
            return {"id": var_name, "op": f"COMP.{kind}", "deps": name_deps, "args": {}}

        for i, stmt in enumerate(fn.body):  # type: ignore[attr-defined]
            if isinstance(stmt, ast.Assign):
                if len(stmt.targets) != 1 or not isinstance(stmt.targets[0], ast.Name):
                    raise DSLParseError("Assignment targets must be simple names")
                var_name = stmt.targets[0].id
                if not VALID_NAME_RE.match(var_name):
                    raise DSLParseError(f"Invalid variable name: {var_name}")
                if var_name in defined:
                    raise DSLParseError(f"Duplicate variable name: {var_name}")

                value = stmt.value
                if isinstance(value, ast.Await):
                    value = value.value
                if isinstance(value, ast.Call):
                    ops.append(_op_from_call(var_name, value))
                elif isinstance(value, ast.JoinedStr):
                    ops.append(_op_from_fstring(var_name, value))
                elif isinstance(value, (ast.Constant, ast.List, ast.Tuple, ast.Dict)):
                    ops.append(_op_from_literal_or_pack(var_name, value))
                elif isinstance(value, (ast.ListComp, ast.SetComp, ast.DictComp, ast.GeneratorExp)):
                    ops.append(_op_from_comp(var_name, value))
                else:
                    raise DSLParseError("Right hand side must be a call or f-string")
                defined.add(var_name)

            elif isinstance(stmt, ast.Expr):
                call = stmt.value
                # Docstrings and ``...`` placeholders are no-ops.
                if isinstance(call, ast.Constant) and (isinstance(call.value, str) or call.value is Ellipsis):
                    continue
                if isinstance(call, ast.Await):
                    call = call.value
                if not isinstance(call, ast.Call):
                    raise DSLParseError("Only call expressions allowed at top level")
                name = _get_call_name(call.func)
                if name == "settings":
                    for kw in call.keywords:
                        if kw.arg is None:
                            raise DSLParseError("settings does not accept **kwargs")
                        settings[kw.arg] = _literal(kw.value)
                    if call.args:
                        raise DSLParseError("settings only accepts keyword literals")
                elif name == "output":
                    if len(call.args) != 1 or not isinstance(call.args[0], ast.Name):
                        raise DSLParseError("output requires a single variable name argument")
                    var = call.args[0].id
                    if var not in defined:
                        raise DSLParseError(f"Undefined output variable: {var}")
                    filename = None
                    for kw in call.keywords:
                        if kw.arg in {"as", "as_"}:
                            filename = _literal(kw.value)
                        else:
                            raise DSLParseError("output only accepts 'as' keyword")
                    if filename is None or not isinstance(filename, str):
                        raise DSLParseError("output requires as=\"filename\"")
                    outputs.append({"from": var, "as": filename})
                else:
                    raise DSLParseError("Only settings() and output() calls allowed as expressions")

            elif isinstance(stmt, ast.Return):
                if i != len(fn.body) - 1:  # type: ignore[attr-defined]
                    raise DSLParseError("return must be the last statement")
                if isinstance(stmt.value, ast.Name):
                    var = stmt.value.id
                    if var not in defined:
                        raise DSLParseError(f"Undefined return variable: {var}")
                    returned_var = var
                elif isinstance(stmt.value, (ast.Constant, ast.List, ast.Tuple, ast.Dict)):
                    # A returned literal becomes a CONST op under a name that
                    # does not clash with user variables.
                    lit = _literal(stmt.value)
                    const_id = "return_value"
                    n = 1
                    while const_id in defined:
                        const_id = f"return_value_{n}"
                        n += 1
                    ops.append({"id": const_id, "op": "CONST.value", "deps": [], "args": {"value": lit}})
                    returned_var = const_id
                else:
                    raise DSLParseError("return must return a variable name or literal")

            elif isinstance(stmt, skipped_blocks):
                # Control-flow blocks are not modelled by this linear planner.
                continue
            elif isinstance(stmt, ast.Pass):
                continue
            else:
                raise DSLParseError("Only assignments, expression calls, and a final return are allowed in function body")

        if not outputs:
            if returned_var is not None:
                outputs.append({"from": returned_var, "as": "return"})
            else:
                raise DSLParseError("At least one output() call required")
        if len(ops) > 200:
            raise DSLParseError("Too many operations")

        # Include the parsed function name for visibility/debugging.
        fn_name = getattr(fn, "name", None)  # type: ignore[attr-defined]
        plan: Dict[str, Any] = {"version": 1, "function": fn_name, "ops": ops, "outputs": outputs}
        if settings:
            plan["settings"] = settings
        return plan

    # Use the named function if given; otherwise take the first one that parses.
    if function_name is not None:
        fn: Optional[ast.AST] = None
        for node in module.body:
            if isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef)) and node.name == function_name:
                fn = node
                break
        if fn is None:
            raise DSLParseError(f"Function {function_name!r} not found")
        return _parse_fn(fn)

    last_err: Optional[Exception] = None
    for node in module.body:
        if isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef)):
            try:
                return _parse_fn(node)
            except DSLParseError as e:
                last_err = e
    if last_err is not None:
        raise DSLParseError("No suitable function matched the DSL; specify --func to disambiguate") from last_err
    raise DSLParseError("No function definitions found in source")
325
-
326
-
327
def parse_file(filename: str, function_name: Optional[str] = None) -> Dict[str, Any]:
    """Read *filename* as UTF-8 text and return ``parse`` applied to its contents.

    Args:
        filename: Path of the Python source file to read.
        function_name: Forwarded to ``parse``; ``None`` auto-detects the function.

    Returns:
        The plan dictionary produced by ``parse``.
    """
    with open(filename, encoding="utf-8") as handle:
        source_text = handle.read()
    return parse(source_text, function_name=function_name)
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes