rappel 0.4.5__py3-none-win_amd64.whl → 0.8.1__py3-none-win_amd64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of rappel might be problematic. Click here for more details.
- proto/ast_pb2.py +79 -75
- proto/ast_pb2.pyi +155 -135
- proto/messages_pb2.py +49 -49
- proto/messages_pb2.pyi +77 -3
- rappel/__init__.py +6 -1
- rappel/actions.py +1 -1
- rappel/bin/boot-rappel-singleton.exe +0 -0
- rappel/bin/rappel-bridge.exe +0 -0
- rappel/bin/start-workers.exe +0 -0
- rappel/bridge.py +48 -48
- rappel/exceptions.py +7 -0
- rappel/ir_builder.py +1177 -453
- rappel/registry.py +5 -0
- rappel/schedule.py +80 -12
- rappel/serialization.py +75 -1
- rappel/workflow.py +32 -13
- rappel/workflow_runtime.py +156 -6
- rappel-0.8.1.data/scripts/boot-rappel-singleton.exe +0 -0
- {rappel-0.4.5.data → rappel-0.8.1.data}/scripts/rappel-bridge.exe +0 -0
- {rappel-0.4.5.data → rappel-0.8.1.data}/scripts/start-workers.exe +0 -0
- {rappel-0.4.5.dist-info → rappel-0.8.1.dist-info}/METADATA +13 -1
- rappel-0.8.1.dist-info/RECORD +32 -0
- rappel-0.4.5.data/scripts/boot-rappel-singleton.exe +0 -0
- rappel-0.4.5.dist-info/RECORD +0 -32
- {rappel-0.4.5.dist-info → rappel-0.8.1.dist-info}/WHEEL +0 -0
- {rappel-0.4.5.dist-info → rappel-0.8.1.dist-info}/entry_points.txt +0 -0
rappel/ir_builder.py
CHANGED
|
@@ -5,15 +5,10 @@ This module parses Python workflow classes and produces the IR representation
|
|
|
5
5
|
that can be sent to the Rust runtime for execution.
|
|
6
6
|
|
|
7
7
|
The IR builder performs deep transformations to convert Python patterns into
|
|
8
|
-
valid Rappel IR structures.
|
|
9
|
-
|
|
10
|
-
|
|
11
|
-
|
|
12
|
-
Transformations:
|
|
13
|
-
1. **Try body wrapping**: Wraps multi-action try bodies into synthetic functions
|
|
14
|
-
2. **For loop body wrapping**: Wraps multi-action for bodies into synthetic functions
|
|
15
|
-
3. **If branch wrapping**: Wraps multi-action if/elif/else branches into synthetic functions
|
|
16
|
-
4. **Exception handler wrapping**: Wraps multi-action handlers into synthetic functions
|
|
8
|
+
valid Rappel IR structures. Control flow bodies (try, for, if branches) are
|
|
9
|
+
represented as full blocks, so they can contain arbitrary statements and
|
|
10
|
+
`return` behaves consistently with Python (returns from the entry function end
|
|
11
|
+
the workflow).
|
|
17
12
|
|
|
18
13
|
Validation:
|
|
19
14
|
The IR builder proactively detects unsupported Python patterns and raises
|
|
@@ -27,9 +22,11 @@ import copy
|
|
|
27
22
|
import inspect
|
|
28
23
|
import textwrap
|
|
29
24
|
from dataclasses import dataclass
|
|
30
|
-
from
|
|
25
|
+
from enum import EnumMeta
|
|
26
|
+
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Mapping, NoReturn, Optional, Set, Union
|
|
31
27
|
|
|
32
28
|
from proto import ast_pb2 as ir
|
|
29
|
+
from rappel.registry import registry
|
|
33
30
|
|
|
34
31
|
|
|
35
32
|
class UnsupportedPatternError(Exception):
|
|
@@ -45,13 +42,22 @@ class UnsupportedPatternError(Exception):
|
|
|
45
42
|
recommendation: str,
|
|
46
43
|
line: Optional[int] = None,
|
|
47
44
|
col: Optional[int] = None,
|
|
45
|
+
filename: Optional[str] = None,
|
|
48
46
|
):
|
|
49
47
|
self.message = message
|
|
50
48
|
self.recommendation = recommendation
|
|
51
49
|
self.line = line
|
|
52
50
|
self.col = col
|
|
53
|
-
|
|
54
|
-
|
|
51
|
+
self.filename = filename
|
|
52
|
+
|
|
53
|
+
location_parts: List[str] = []
|
|
54
|
+
if filename:
|
|
55
|
+
location_parts.append(filename)
|
|
56
|
+
if line:
|
|
57
|
+
location_parts.append(f"line {line}")
|
|
58
|
+
if col is not None:
|
|
59
|
+
location_parts.append(f"col {col}")
|
|
60
|
+
location = f" ({', '.join(location_parts)})" if location_parts else ""
|
|
55
61
|
full_message = f"{message}{location}\n\nRecommendation: {recommendation}"
|
|
56
62
|
super().__init__(full_message)
|
|
57
63
|
|
|
@@ -223,8 +229,34 @@ RECOMMENDATIONS = {
|
|
|
223
229
|
"Yield statements are not supported in workflow code.\n"
|
|
224
230
|
"Workflows must return a complete result, not generate values incrementally."
|
|
225
231
|
),
|
|
232
|
+
"continue_statement": (
|
|
233
|
+
"Continue statements are not supported in workflow code.\n"
|
|
234
|
+
"Restructure your loop using if/else to skip iterations."
|
|
235
|
+
),
|
|
236
|
+
"unsupported_statement": (
|
|
237
|
+
"This statement type is not supported in workflow code.\n"
|
|
238
|
+
"Move the logic into an @action or rewrite using supported statements."
|
|
239
|
+
),
|
|
240
|
+
"unsupported_expression": (
|
|
241
|
+
"This expression type is not supported in workflow code.\n"
|
|
242
|
+
"Move the logic into an @action or rewrite using supported expressions."
|
|
243
|
+
),
|
|
244
|
+
"unsupported_literal": (
|
|
245
|
+
"This literal type is not supported in workflow code.\n"
|
|
246
|
+
"Convert the value to a supported literal type inside an @action."
|
|
247
|
+
),
|
|
226
248
|
}
|
|
227
249
|
|
|
250
|
+
GLOBAL_FUNCTIONS = {
|
|
251
|
+
"enumerate": ir.GlobalFunction.GLOBAL_FUNCTION_ENUMERATE,
|
|
252
|
+
"isexception": ir.GlobalFunction.GLOBAL_FUNCTION_ISEXCEPTION,
|
|
253
|
+
"len": ir.GlobalFunction.GLOBAL_FUNCTION_LEN,
|
|
254
|
+
"range": ir.GlobalFunction.GLOBAL_FUNCTION_RANGE,
|
|
255
|
+
}
|
|
256
|
+
ALLOWED_SYNC_FUNCTIONS = set(GLOBAL_FUNCTIONS)
|
|
257
|
+
|
|
258
|
+
_CURRENT_ACTION_NAMES: set[str] = set()
|
|
259
|
+
|
|
228
260
|
|
|
229
261
|
if TYPE_CHECKING:
|
|
230
262
|
from .workflow import Workflow
|
|
@@ -239,6 +271,32 @@ class ActionDefinition:
|
|
|
239
271
|
signature: inspect.Signature
|
|
240
272
|
|
|
241
273
|
|
|
274
|
+
@dataclass
|
|
275
|
+
class FunctionParameter:
|
|
276
|
+
"""A function parameter with optional default value."""
|
|
277
|
+
|
|
278
|
+
name: str
|
|
279
|
+
has_default: bool
|
|
280
|
+
default_value: Any = None # The actual Python value if has_default is True
|
|
281
|
+
|
|
282
|
+
|
|
283
|
+
@dataclass
|
|
284
|
+
class FunctionSignatureInfo:
|
|
285
|
+
"""Signature info for a workflow function, used to fill in default arguments."""
|
|
286
|
+
|
|
287
|
+
parameters: List[FunctionParameter]
|
|
288
|
+
|
|
289
|
+
|
|
290
|
+
@dataclass
|
|
291
|
+
class ModuleContext:
|
|
292
|
+
"""Cached IRBuilder context derived from a module."""
|
|
293
|
+
|
|
294
|
+
action_defs: Dict[str, ActionDefinition]
|
|
295
|
+
imported_names: Dict[str, "ImportedName"]
|
|
296
|
+
module_functions: Set[str]
|
|
297
|
+
model_defs: Dict[str, "ModelDefinition"]
|
|
298
|
+
|
|
299
|
+
|
|
242
300
|
@dataclass
|
|
243
301
|
class TransformContext:
|
|
244
302
|
"""Context for IR transformations."""
|
|
@@ -247,10 +305,14 @@ class TransformContext:
|
|
|
247
305
|
implicit_fn_counter: int = 0
|
|
248
306
|
# Implicit functions generated during transformation
|
|
249
307
|
implicit_functions: List[ir.FunctionDef] = None # type: ignore
|
|
308
|
+
# Function signatures for workflow methods (used to fill in default arguments)
|
|
309
|
+
function_signatures: Dict[str, FunctionSignatureInfo] = None # type: ignore
|
|
250
310
|
|
|
251
311
|
def __post_init__(self) -> None:
|
|
252
312
|
if self.implicit_functions is None:
|
|
253
313
|
self.implicit_functions = []
|
|
314
|
+
if self.function_signatures is None:
|
|
315
|
+
self.function_signatures = {}
|
|
254
316
|
|
|
255
317
|
def next_implicit_fn_name(self, prefix: str = "implicit") -> str:
|
|
256
318
|
"""Generate a unique implicit function name."""
|
|
@@ -277,41 +339,218 @@ def build_workflow_ir(workflow_cls: type["Workflow"]) -> ir.Program:
|
|
|
277
339
|
if module is None:
|
|
278
340
|
raise ValueError(f"unable to locate module for workflow {workflow_cls!r}")
|
|
279
341
|
|
|
280
|
-
|
|
281
|
-
function_source = textwrap.dedent(inspect.getsource(original_run))
|
|
282
|
-
tree = ast.parse(function_source)
|
|
342
|
+
module_contexts: Dict[str, ModuleContext] = {}
|
|
283
343
|
|
|
284
|
-
|
|
285
|
-
|
|
344
|
+
def get_module_context(target_module: Any) -> ModuleContext:
|
|
345
|
+
module_name = target_module.__name__
|
|
346
|
+
if module_name not in module_contexts:
|
|
347
|
+
module_contexts[module_name] = ModuleContext(
|
|
348
|
+
action_defs=_discover_action_names(target_module),
|
|
349
|
+
imported_names=_discover_module_imports(target_module),
|
|
350
|
+
module_functions=_discover_module_functions(target_module),
|
|
351
|
+
model_defs=_discover_model_definitions(target_module),
|
|
352
|
+
)
|
|
353
|
+
return module_contexts[module_name]
|
|
286
354
|
|
|
287
|
-
#
|
|
288
|
-
|
|
355
|
+
# Build the IR with transformation context
|
|
356
|
+
ctx = TransformContext()
|
|
357
|
+
program = ir.Program()
|
|
358
|
+
function_defs: Dict[str, ir.FunctionDef] = {}
|
|
359
|
+
|
|
360
|
+
# Extract instance attributes from __init__ for policy resolution
|
|
361
|
+
instance_attrs = _extract_instance_attrs(workflow_cls)
|
|
362
|
+
|
|
363
|
+
def parse_function(fn: Any) -> tuple[ast.AST, Optional[str], int]:
|
|
364
|
+
source_lines, start_line = inspect.getsourcelines(fn)
|
|
365
|
+
function_source = textwrap.dedent("".join(source_lines))
|
|
366
|
+
filename = inspect.getsourcefile(fn)
|
|
367
|
+
if filename is None:
|
|
368
|
+
filename = inspect.getfile(fn)
|
|
369
|
+
return ast.parse(function_source, filename=filename or "<unknown>"), filename, start_line
|
|
370
|
+
|
|
371
|
+
def _with_source_location(
|
|
372
|
+
err: UnsupportedPatternError,
|
|
373
|
+
filename: Optional[str],
|
|
374
|
+
start_line: int,
|
|
375
|
+
) -> UnsupportedPatternError:
|
|
376
|
+
line = err.line
|
|
377
|
+
col = err.col
|
|
378
|
+
if line is not None:
|
|
379
|
+
line = start_line + line - 1
|
|
380
|
+
if col is not None:
|
|
381
|
+
col = col + 1
|
|
382
|
+
return UnsupportedPatternError(
|
|
383
|
+
err.message,
|
|
384
|
+
err.recommendation,
|
|
385
|
+
line=line,
|
|
386
|
+
col=col,
|
|
387
|
+
filename=filename,
|
|
388
|
+
)
|
|
289
389
|
|
|
290
|
-
|
|
291
|
-
|
|
390
|
+
def add_function_def(
|
|
391
|
+
fn: Any,
|
|
392
|
+
fn_tree: ast.AST,
|
|
393
|
+
filename: Optional[str],
|
|
394
|
+
start_line: int,
|
|
395
|
+
override_name: Optional[str] = None,
|
|
396
|
+
) -> None:
|
|
397
|
+
global _CURRENT_ACTION_NAMES
|
|
398
|
+
fn_module = inspect.getmodule(fn)
|
|
399
|
+
if fn_module is None:
|
|
400
|
+
raise ValueError(f"unable to locate module for function {fn!r}")
|
|
401
|
+
|
|
402
|
+
ctx_data = get_module_context(fn_module)
|
|
403
|
+
_CURRENT_ACTION_NAMES = set(ctx_data.action_defs.keys())
|
|
404
|
+
builder = IRBuilder(
|
|
405
|
+
ctx_data.action_defs,
|
|
406
|
+
ctx,
|
|
407
|
+
ctx_data.imported_names,
|
|
408
|
+
ctx_data.module_functions,
|
|
409
|
+
ctx_data.model_defs,
|
|
410
|
+
fn_module.__dict__,
|
|
411
|
+
instance_attrs,
|
|
412
|
+
)
|
|
413
|
+
try:
|
|
414
|
+
builder.visit(fn_tree)
|
|
415
|
+
except UnsupportedPatternError as err:
|
|
416
|
+
raise _with_source_location(err, filename, start_line) from err
|
|
417
|
+
if builder.function_def:
|
|
418
|
+
if override_name:
|
|
419
|
+
builder.function_def.name = override_name
|
|
420
|
+
function_defs[builder.function_def.name] = builder.function_def
|
|
421
|
+
|
|
422
|
+
# Discover all reachable helper methods first so we can pre-collect their signatures.
|
|
423
|
+
# This is needed because when we process a function that calls another function,
|
|
424
|
+
# we need to know the callee's signature to fill in default arguments.
|
|
425
|
+
run_tree, run_filename, run_start_line = parse_function(original_run)
|
|
426
|
+
|
|
427
|
+
# Collect all reachable methods and their trees
|
|
428
|
+
methods_to_process: List[tuple[Any, ast.AST, Optional[str], int, Optional[str]]] = [
|
|
429
|
+
(original_run, run_tree, run_filename, run_start_line, "main")
|
|
430
|
+
]
|
|
431
|
+
pending = list(_collect_self_method_calls(run_tree))
|
|
432
|
+
visited: Set[str] = set()
|
|
433
|
+
skip_methods = {"run_action"}
|
|
434
|
+
|
|
435
|
+
while pending:
|
|
436
|
+
method_name = pending.pop()
|
|
437
|
+
if method_name in visited or method_name == "run" or method_name in skip_methods:
|
|
438
|
+
continue
|
|
439
|
+
visited.add(method_name)
|
|
292
440
|
|
|
293
|
-
|
|
294
|
-
|
|
441
|
+
method = _find_workflow_method(workflow_cls, method_name)
|
|
442
|
+
if method is None:
|
|
443
|
+
continue
|
|
295
444
|
|
|
296
|
-
|
|
297
|
-
|
|
298
|
-
|
|
299
|
-
|
|
445
|
+
method_tree, method_filename, method_start_line = parse_function(method)
|
|
446
|
+
methods_to_process.append((method, method_tree, method_filename, method_start_line, None))
|
|
447
|
+
pending.extend(_collect_self_method_calls(method_tree))
|
|
448
|
+
|
|
449
|
+
# Pre-collect signatures for all methods before processing IR.
|
|
450
|
+
# This ensures that when we encounter a function call, we can fill in default args.
|
|
451
|
+
for fn, _fn_tree, _filename, _start_line, override_name in methods_to_process:
|
|
452
|
+
fn_name = override_name if override_name else fn.__name__
|
|
453
|
+
sig = inspect.signature(fn)
|
|
454
|
+
params: List[FunctionParameter] = []
|
|
455
|
+
for param_name, param in sig.parameters.items():
|
|
456
|
+
if param_name == "self":
|
|
457
|
+
continue
|
|
458
|
+
has_default = param.default is not inspect.Parameter.empty
|
|
459
|
+
default_value = param.default if has_default else None
|
|
460
|
+
params.append(
|
|
461
|
+
FunctionParameter(
|
|
462
|
+
name=param_name,
|
|
463
|
+
has_default=has_default,
|
|
464
|
+
default_value=default_value,
|
|
465
|
+
)
|
|
466
|
+
)
|
|
467
|
+
ctx.function_signatures[fn_name] = FunctionSignatureInfo(parameters=params)
|
|
300
468
|
|
|
301
|
-
#
|
|
302
|
-
|
|
469
|
+
# Now process all functions with signatures already available
|
|
470
|
+
for fn, fn_tree, filename, start_line, override_name in methods_to_process:
|
|
471
|
+
add_function_def(fn, fn_tree, filename, start_line, override_name)
|
|
303
472
|
|
|
304
473
|
# Add implicit functions first (they may be called by the main function)
|
|
305
474
|
for implicit_fn in ctx.implicit_functions:
|
|
306
475
|
program.functions.append(implicit_fn)
|
|
307
476
|
|
|
308
|
-
# Add
|
|
309
|
-
|
|
310
|
-
program.functions.append(
|
|
477
|
+
# Add all function definitions (run + reachable helper methods)
|
|
478
|
+
for fn_def in function_defs.values():
|
|
479
|
+
program.functions.append(fn_def)
|
|
480
|
+
|
|
481
|
+
global _CURRENT_ACTION_NAMES
|
|
482
|
+
_CURRENT_ACTION_NAMES = set()
|
|
311
483
|
|
|
312
484
|
return program
|
|
313
485
|
|
|
314
486
|
|
|
487
|
+
def _collect_self_method_calls(tree: ast.AST) -> Set[str]:
|
|
488
|
+
"""Collect self.method(...) call names from a parsed function AST."""
|
|
489
|
+
calls: Set[str] = set()
|
|
490
|
+
for node in ast.walk(tree):
|
|
491
|
+
if not isinstance(node, ast.Call):
|
|
492
|
+
continue
|
|
493
|
+
func = node.func
|
|
494
|
+
if isinstance(func, ast.Attribute) and isinstance(func.value, ast.Name):
|
|
495
|
+
if func.value.id == "self":
|
|
496
|
+
calls.add(func.attr)
|
|
497
|
+
return calls
|
|
498
|
+
|
|
499
|
+
|
|
500
|
+
def _find_workflow_method(workflow_cls: type["Workflow"], name: str) -> Optional[Any]:
|
|
501
|
+
"""Find a workflow method by name across the class MRO."""
|
|
502
|
+
for base in workflow_cls.__mro__:
|
|
503
|
+
if name not in base.__dict__:
|
|
504
|
+
continue
|
|
505
|
+
value = base.__dict__[name]
|
|
506
|
+
if isinstance(value, staticmethod) or isinstance(value, classmethod):
|
|
507
|
+
return value.__func__
|
|
508
|
+
if inspect.isfunction(value):
|
|
509
|
+
return value
|
|
510
|
+
return None
|
|
511
|
+
|
|
512
|
+
|
|
513
|
+
def _extract_instance_attrs(workflow_cls: type["Workflow"]) -> Dict[str, ast.expr]:
|
|
514
|
+
"""Extract self.attr = value assignments from the workflow's __init__ method.
|
|
515
|
+
|
|
516
|
+
Parses the __init__ method to find assignments like:
|
|
517
|
+
self.retry_policy = RetryPolicy(attempts=3)
|
|
518
|
+
self.timeout = 30
|
|
519
|
+
|
|
520
|
+
Returns a dict mapping attribute names to their AST value nodes.
|
|
521
|
+
"""
|
|
522
|
+
init_method = _find_workflow_method(workflow_cls, "__init__")
|
|
523
|
+
if init_method is None:
|
|
524
|
+
return {}
|
|
525
|
+
|
|
526
|
+
try:
|
|
527
|
+
source_lines, _ = inspect.getsourcelines(init_method)
|
|
528
|
+
source = textwrap.dedent("".join(source_lines))
|
|
529
|
+
tree = ast.parse(source)
|
|
530
|
+
except (OSError, TypeError, SyntaxError):
|
|
531
|
+
return {}
|
|
532
|
+
|
|
533
|
+
attrs: Dict[str, ast.expr] = {}
|
|
534
|
+
|
|
535
|
+
# Walk the __init__ body looking for self.attr = value assignments
|
|
536
|
+
for node in ast.walk(tree):
|
|
537
|
+
if not isinstance(node, ast.Assign):
|
|
538
|
+
continue
|
|
539
|
+
# Only handle single-target assignments
|
|
540
|
+
if len(node.targets) != 1:
|
|
541
|
+
continue
|
|
542
|
+
target = node.targets[0]
|
|
543
|
+
# Check for self.attr pattern
|
|
544
|
+
if (
|
|
545
|
+
isinstance(target, ast.Attribute)
|
|
546
|
+
and isinstance(target.value, ast.Name)
|
|
547
|
+
and target.value.id == "self"
|
|
548
|
+
):
|
|
549
|
+
attrs[target.attr] = node.value
|
|
550
|
+
|
|
551
|
+
return attrs
|
|
552
|
+
|
|
553
|
+
|
|
315
554
|
def _discover_action_names(module: Any) -> Dict[str, ActionDefinition]:
|
|
316
555
|
"""Discover all @action decorated functions in a module."""
|
|
317
556
|
names: Dict[str, ActionDefinition] = {}
|
|
@@ -326,6 +565,18 @@ def _discover_action_names(module: Any) -> Dict[str, ActionDefinition]:
|
|
|
326
565
|
module_name=action_module or module.__name__,
|
|
327
566
|
signature=signature,
|
|
328
567
|
)
|
|
568
|
+
for entry in registry.entries():
|
|
569
|
+
if entry.module != module.__name__:
|
|
570
|
+
continue
|
|
571
|
+
func_name = entry.func.__name__
|
|
572
|
+
if func_name in names:
|
|
573
|
+
continue
|
|
574
|
+
signature = inspect.signature(entry.func)
|
|
575
|
+
names[func_name] = ActionDefinition(
|
|
576
|
+
action_name=entry.name,
|
|
577
|
+
module_name=entry.module,
|
|
578
|
+
signature=signature,
|
|
579
|
+
)
|
|
329
580
|
return names
|
|
330
581
|
|
|
331
582
|
|
|
@@ -592,21 +843,22 @@ class IRBuilder(ast.NodeVisitor):
|
|
|
592
843
|
imported_names: Optional[Dict[str, ImportedName]] = None,
|
|
593
844
|
module_functions: Optional[Set[str]] = None,
|
|
594
845
|
model_defs: Optional[Dict[str, ModelDefinition]] = None,
|
|
846
|
+
module_globals: Optional[Mapping[str, Any]] = None,
|
|
847
|
+
instance_attrs: Optional[Dict[str, ast.expr]] = None,
|
|
595
848
|
):
|
|
596
849
|
self._action_defs = action_defs
|
|
597
850
|
self._ctx = ctx
|
|
598
851
|
self._imported_names = imported_names or {}
|
|
599
852
|
self._module_functions = module_functions or set()
|
|
600
853
|
self._model_defs = model_defs or {}
|
|
854
|
+
self._module_globals = module_globals or {}
|
|
855
|
+
self._instance_attrs = instance_attrs or {}
|
|
601
856
|
self.function_def: Optional[ir.FunctionDef] = None
|
|
602
857
|
self._statements: List[ir.Statement] = []
|
|
603
858
|
|
|
604
859
|
def visit_FunctionDef(self, node: ast.FunctionDef) -> Any:
|
|
605
860
|
"""Visit a function definition (the workflow's run method)."""
|
|
606
|
-
|
|
607
|
-
inputs: List[str] = []
|
|
608
|
-
for arg in node.args.args[1:]: # Skip 'self'
|
|
609
|
-
inputs.append(arg.arg)
|
|
861
|
+
inputs = self._collect_function_inputs(node)
|
|
610
862
|
|
|
611
863
|
# Create the function definition
|
|
612
864
|
self.function_def = ir.FunctionDef(
|
|
@@ -627,9 +879,7 @@ class IRBuilder(ast.NodeVisitor):
|
|
|
627
879
|
def visit_AsyncFunctionDef(self, node: ast.AsyncFunctionDef) -> Any:
|
|
628
880
|
"""Visit an async function definition (the workflow's run method)."""
|
|
629
881
|
# Handle async the same way as sync for IR building
|
|
630
|
-
inputs
|
|
631
|
-
for arg in node.args.args[1:]: # Skip 'self'
|
|
632
|
-
inputs.append(arg.arg)
|
|
882
|
+
inputs = self._collect_function_inputs(node)
|
|
633
883
|
|
|
634
884
|
self.function_def = ir.FunctionDef(
|
|
635
885
|
name=node.name,
|
|
@@ -661,32 +911,42 @@ class IRBuilder(ast.NodeVisitor):
|
|
|
661
911
|
return expanded
|
|
662
912
|
result = self._visit_assign(node)
|
|
663
913
|
return [result] if result else []
|
|
914
|
+
elif isinstance(node, ast.AnnAssign):
|
|
915
|
+
result = self._visit_ann_assign(node)
|
|
916
|
+
return [result] if result else []
|
|
664
917
|
elif isinstance(node, ast.Expr):
|
|
665
918
|
result = self._visit_expr_stmt(node)
|
|
666
919
|
return [result] if result else []
|
|
667
920
|
elif isinstance(node, ast.For):
|
|
668
921
|
return self._visit_for(node)
|
|
669
922
|
elif isinstance(node, ast.If):
|
|
670
|
-
|
|
671
|
-
return [result] if result else []
|
|
923
|
+
return self._visit_if(node)
|
|
672
924
|
elif isinstance(node, ast.Try):
|
|
673
925
|
return self._visit_try(node)
|
|
674
926
|
elif isinstance(node, ast.Return):
|
|
675
927
|
return self._visit_return(node)
|
|
676
928
|
elif isinstance(node, ast.AugAssign):
|
|
677
|
-
|
|
678
|
-
return [result] if result else []
|
|
929
|
+
return self._visit_aug_assign(node)
|
|
679
930
|
elif isinstance(node, ast.Pass):
|
|
680
931
|
# Pass statements are fine, they just don't produce IR
|
|
681
932
|
return []
|
|
933
|
+
elif isinstance(node, ast.Break):
|
|
934
|
+
return self._visit_break(node)
|
|
935
|
+
elif isinstance(node, ast.Continue):
|
|
936
|
+
return self._visit_continue(node)
|
|
682
937
|
|
|
683
|
-
# Check for unsupported statement types
|
|
938
|
+
# Check for unsupported statement types - this MUST raise for any
|
|
939
|
+
# unhandled statement to avoid silently dropping code
|
|
684
940
|
self._check_unsupported_statement(node)
|
|
685
941
|
|
|
686
|
-
|
|
942
|
+
def _check_unsupported_statement(self, node: ast.stmt) -> NoReturn:
|
|
943
|
+
"""Check for unsupported statement types and raise descriptive errors.
|
|
687
944
|
|
|
688
|
-
|
|
689
|
-
|
|
945
|
+
This function ALWAYS raises an exception - it never returns normally.
|
|
946
|
+
Any statement type that reaches this function is either explicitly
|
|
947
|
+
unsupported (with a specific error message) or unhandled (with a
|
|
948
|
+
generic catch-all error). This ensures we never silently drop code.
|
|
949
|
+
"""
|
|
690
950
|
line = getattr(node, "lineno", None)
|
|
691
951
|
col = getattr(node, "col_offset", None)
|
|
692
952
|
|
|
@@ -767,6 +1027,16 @@ class IRBuilder(ast.NodeVisitor):
|
|
|
767
1027
|
line=line,
|
|
768
1028
|
col=col,
|
|
769
1029
|
)
|
|
1030
|
+
else:
|
|
1031
|
+
# Catch-all for any unhandled statement types.
|
|
1032
|
+
# This is critical to avoid silently dropping code.
|
|
1033
|
+
stmt_type = type(node).__name__
|
|
1034
|
+
raise UnsupportedPatternError(
|
|
1035
|
+
f"Unhandled statement type: {stmt_type}",
|
|
1036
|
+
RECOMMENDATIONS["unsupported_statement"],
|
|
1037
|
+
line=line,
|
|
1038
|
+
col=col,
|
|
1039
|
+
)
|
|
770
1040
|
|
|
771
1041
|
def _expand_list_comprehension_assignment(
|
|
772
1042
|
self, node: ast.Assign
|
|
@@ -985,6 +1255,16 @@ class IRBuilder(ast.NodeVisitor):
|
|
|
985
1255
|
|
|
986
1256
|
return statements
|
|
987
1257
|
|
|
1258
|
+
def _visit_ann_assign(self, node: ast.AnnAssign) -> Optional[ir.Statement]:
|
|
1259
|
+
"""Convert annotated assignment to IR when a value is present."""
|
|
1260
|
+
if node.value is None:
|
|
1261
|
+
return None
|
|
1262
|
+
|
|
1263
|
+
assign = ast.Assign(targets=[node.target], value=node.value, type_comment=None)
|
|
1264
|
+
ast.copy_location(assign, node)
|
|
1265
|
+
ast.fix_missing_locations(assign)
|
|
1266
|
+
return self._visit_assign(assign)
|
|
1267
|
+
|
|
988
1268
|
def _visit_assign(self, node: ast.Assign) -> Optional[ir.Statement]:
|
|
989
1269
|
"""Convert assignment to IR.
|
|
990
1270
|
|
|
@@ -1038,7 +1318,7 @@ class IRBuilder(ast.NodeVisitor):
|
|
|
1038
1318
|
return stmt
|
|
1039
1319
|
|
|
1040
1320
|
# Regular assignment (variables, literals, expressions)
|
|
1041
|
-
value_expr =
|
|
1321
|
+
value_expr = self._expr_to_ir_with_model_coercion(node.value)
|
|
1042
1322
|
if value_expr:
|
|
1043
1323
|
assign = ir.Assignment(targets=targets, value=value_expr)
|
|
1044
1324
|
stmt.assignment.CopyFrom(assign)
|
|
@@ -1090,7 +1370,7 @@ class IRBuilder(ast.NodeVisitor):
|
|
|
1090
1370
|
append_value = call.args[0]
|
|
1091
1371
|
# Create: list = list + [value]
|
|
1092
1372
|
list_var = ir.Expr(variable=ir.Variable(name=list_name), span=_make_span(node))
|
|
1093
|
-
value_expr =
|
|
1373
|
+
value_expr = self._expr_to_ir_with_model_coercion(append_value)
|
|
1094
1374
|
if value_expr:
|
|
1095
1375
|
# Create [value] as a list literal
|
|
1096
1376
|
list_literal = ir.Expr(
|
|
@@ -1108,7 +1388,7 @@ class IRBuilder(ast.NodeVisitor):
|
|
|
1108
1388
|
return stmt
|
|
1109
1389
|
|
|
1110
1390
|
# Regular expression
|
|
1111
|
-
expr =
|
|
1391
|
+
expr = self._expr_to_ir_with_model_coercion(node.value)
|
|
1112
1392
|
if expr:
|
|
1113
1393
|
stmt.expr_stmt.CopyFrom(ir.ExprStmt(expr=expr))
|
|
1114
1394
|
return stmt
|
|
@@ -1116,34 +1396,10 @@ class IRBuilder(ast.NodeVisitor):
|
|
|
1116
1396
|
return None
|
|
1117
1397
|
|
|
1118
1398
|
def _visit_for(self, node: ast.For) -> List[ir.Statement]:
|
|
1119
|
-
"""Convert for loop to IR
|
|
1120
|
-
|
|
1121
|
-
|
|
1122
|
-
|
|
1123
|
-
|
|
1124
|
-
For loops that modify out-of-scope variables (accumulators) are detected
|
|
1125
|
-
and those variables are set as targets on the SingleCallBody. This enables
|
|
1126
|
-
the runtime to properly aggregate results into those variables.
|
|
1127
|
-
|
|
1128
|
-
Supported accumulator patterns:
|
|
1129
|
-
1. List append: results.append(value)
|
|
1130
|
-
2. Dict subscript: result[key] = value
|
|
1131
|
-
3. List concatenation: results = results + [value]
|
|
1132
|
-
4. Counter increment: count = count + 1
|
|
1133
|
-
|
|
1134
|
-
Python:
|
|
1135
|
-
for item in items:
|
|
1136
|
-
a = await step_one(item)
|
|
1137
|
-
b = await step_two(a)
|
|
1138
|
-
|
|
1139
|
-
Becomes IR equivalent of:
|
|
1140
|
-
fn __for_body_1__(item):
|
|
1141
|
-
a = @step_one(item=item)
|
|
1142
|
-
b = @step_two(a=a)
|
|
1143
|
-
return b
|
|
1144
|
-
|
|
1145
|
-
for item in items:
|
|
1146
|
-
__for_body_1__(item=item)
|
|
1399
|
+
"""Convert for loop to IR.
|
|
1400
|
+
|
|
1401
|
+
The loop body is emitted as a full block so it can contain multiple
|
|
1402
|
+
statements/calls and early `return`.
|
|
1147
1403
|
"""
|
|
1148
1404
|
# Get loop variables
|
|
1149
1405
|
loop_vars: List[str] = []
|
|
@@ -1155,41 +1411,21 @@ class IRBuilder(ast.NodeVisitor):
|
|
|
1155
1411
|
loop_vars.append(elt.id)
|
|
1156
1412
|
|
|
1157
1413
|
# Get iterable
|
|
1158
|
-
iterable =
|
|
1414
|
+
iterable = self._expr_to_ir_with_model_coercion(node.iter)
|
|
1159
1415
|
if not iterable:
|
|
1160
1416
|
return []
|
|
1161
1417
|
|
|
1162
|
-
# Collect variables defined within the loop body (in-scope)
|
|
1163
|
-
in_scope_vars = set(loop_vars)
|
|
1164
|
-
|
|
1165
1418
|
# Build body statements (recursively transforms nested structures)
|
|
1166
1419
|
body_stmts: List[ir.Statement] = []
|
|
1167
1420
|
for body_node in node.body:
|
|
1168
1421
|
stmts = self._visit_statement(body_node)
|
|
1169
1422
|
body_stmts.extend(stmts)
|
|
1170
|
-
# Track variables defined by assignments in this iteration
|
|
1171
|
-
for s in stmts:
|
|
1172
|
-
if s.HasField("assignment"):
|
|
1173
|
-
in_scope_vars.update(s.assignment.targets)
|
|
1174
|
-
|
|
1175
|
-
# Detect all out-of-scope variable modifications
|
|
1176
|
-
# These are variables modified in the loop body but defined outside it
|
|
1177
|
-
modified_vars = self._detect_accumulator_targets(body_stmts, in_scope_vars)
|
|
1178
|
-
|
|
1179
|
-
# ALWAYS wrap for loop body into a synthetic function for variable isolation.
|
|
1180
|
-
# Variables flow in/out explicitly through function parameters and return values.
|
|
1181
|
-
body_stmts = self._wrap_body_as_function(
|
|
1182
|
-
body_stmts, "for_body", node, inputs=loop_vars, modified_vars=modified_vars
|
|
1183
|
-
)
|
|
1184
1423
|
|
|
1185
|
-
# Convert to SingleCallBody (now contains just the synthetic function call)
|
|
1186
1424
|
stmt = ir.Statement(span=_make_span(node))
|
|
1187
|
-
single_call_body = self._stmts_to_single_call_body(body_stmts, _make_span(node))
|
|
1188
|
-
|
|
1189
1425
|
for_loop = ir.ForLoop(
|
|
1190
1426
|
loop_vars=loop_vars,
|
|
1191
1427
|
iterable=iterable,
|
|
1192
|
-
|
|
1428
|
+
block_body=ir.Block(statements=body_stmts, span=_make_span(node)),
|
|
1193
1429
|
)
|
|
1194
1430
|
stmt.for_loop.CopyFrom(for_loop)
|
|
1195
1431
|
return [stmt]
|
|
@@ -1223,26 +1459,22 @@ class IRBuilder(ast.NodeVisitor):
|
|
|
1223
1459
|
# Check conditionals for accumulator targets in branch bodies
|
|
1224
1460
|
if stmt.HasField("conditional"):
|
|
1225
1461
|
cond = stmt.conditional
|
|
1226
|
-
|
|
1227
|
-
|
|
1228
|
-
|
|
1229
|
-
|
|
1230
|
-
|
|
1231
|
-
|
|
1232
|
-
|
|
1233
|
-
|
|
1234
|
-
|
|
1235
|
-
|
|
1236
|
-
|
|
1237
|
-
|
|
1238
|
-
|
|
1239
|
-
|
|
1240
|
-
|
|
1241
|
-
|
|
1242
|
-
):
|
|
1243
|
-
if var not in seen:
|
|
1244
|
-
accumulators.append(var)
|
|
1245
|
-
seen.add(var)
|
|
1462
|
+
branch_blocks: list[ir.Block] = []
|
|
1463
|
+
if cond.HasField("if_branch") and cond.if_branch.HasField("block_body"):
|
|
1464
|
+
branch_blocks.append(cond.if_branch.block_body)
|
|
1465
|
+
for branch in cond.elif_branches:
|
|
1466
|
+
if branch.HasField("block_body"):
|
|
1467
|
+
branch_blocks.append(branch.block_body)
|
|
1468
|
+
if cond.HasField("else_branch") and cond.else_branch.HasField("block_body"):
|
|
1469
|
+
branch_blocks.append(cond.else_branch.block_body)
|
|
1470
|
+
|
|
1471
|
+
for block in branch_blocks:
|
|
1472
|
+
for var in self._detect_accumulator_targets(
|
|
1473
|
+
list(block.statements), in_scope_vars
|
|
1474
|
+
):
|
|
1475
|
+
if var not in seen:
|
|
1476
|
+
accumulators.append(var)
|
|
1477
|
+
seen.add(var)
|
|
1246
1478
|
|
|
1247
1479
|
return accumulators
|
|
1248
1480
|
|
|
@@ -1316,10 +1548,9 @@ class IRBuilder(ast.NodeVisitor):
|
|
|
1316
1548
|
for elem in expr.list.elements:
|
|
1317
1549
|
vars_found.update(self._collect_variables_from_expr(elem))
|
|
1318
1550
|
elif expr.HasField("dict"):
|
|
1319
|
-
for
|
|
1320
|
-
vars_found.update(self._collect_variables_from_expr(key))
|
|
1321
|
-
|
|
1322
|
-
vars_found.update(self._collect_variables_from_expr(val))
|
|
1551
|
+
for entry in expr.dict.entries:
|
|
1552
|
+
vars_found.update(self._collect_variables_from_expr(entry.key))
|
|
1553
|
+
vars_found.update(self._collect_variables_from_expr(entry.value))
|
|
1323
1554
|
elif expr.HasField("index"):
|
|
1324
1555
|
vars_found.update(self._collect_variables_from_expr(expr.index.value))
|
|
1325
1556
|
vars_found.update(self._collect_variables_from_expr(expr.index.index))
|
|
@@ -1334,112 +1565,143 @@ class IRBuilder(ast.NodeVisitor):
|
|
|
1334
1565
|
|
|
1335
1566
|
return vars_found
|
|
1336
1567
|
|
|
1337
|
-
def _visit_if(self, node: ast.If) ->
|
|
1338
|
-
"""Convert if statement to IR
|
|
1568
|
+
def _visit_if(self, node: ast.If) -> List[ir.Statement]:
|
|
1569
|
+
"""Convert if statement to IR.
|
|
1339
1570
|
|
|
1340
|
-
|
|
1341
|
-
|
|
1571
|
+
Normalizes patterns like:
|
|
1572
|
+
if await some_action(...):
|
|
1573
|
+
...
|
|
1574
|
+
into:
|
|
1575
|
+
__if_cond_n__ = await some_action(...)
|
|
1576
|
+
if __if_cond_n__:
|
|
1577
|
+
...
|
|
1578
|
+
"""
|
|
1342
1579
|
|
|
1343
|
-
|
|
1344
|
-
|
|
1580
|
+
def normalize_condition(test: ast.expr) -> tuple[List[ir.Statement], Optional[ir.Expr]]:
|
|
1581
|
+
action_call = self._extract_action_call(test)
|
|
1582
|
+
if action_call is None:
|
|
1583
|
+
return ([], self._expr_to_ir_with_model_coercion(test))
|
|
1345
1584
|
|
|
1346
|
-
|
|
1347
|
-
|
|
1348
|
-
|
|
1349
|
-
|
|
1350
|
-
|
|
1351
|
-
|
|
1585
|
+
if not isinstance(test, ast.Await):
|
|
1586
|
+
line = getattr(test, "lineno", None)
|
|
1587
|
+
col = getattr(test, "col_offset", None)
|
|
1588
|
+
raise UnsupportedPatternError(
|
|
1589
|
+
"Action calls inside boolean expressions are not supported in if conditions",
|
|
1590
|
+
"Assign the awaited action result to a variable, then use the variable in the if condition.",
|
|
1591
|
+
line=line,
|
|
1592
|
+
col=col,
|
|
1593
|
+
)
|
|
1352
1594
|
|
|
1353
|
-
|
|
1354
|
-
|
|
1355
|
-
|
|
1356
|
-
|
|
1357
|
-
|
|
1595
|
+
cond_var = self._ctx.next_implicit_fn_name(prefix="if_cond")
|
|
1596
|
+
assign_stmt = ir.Statement(span=_make_span(test))
|
|
1597
|
+
assign_stmt.assignment.CopyFrom(
|
|
1598
|
+
ir.Assignment(
|
|
1599
|
+
targets=[cond_var],
|
|
1600
|
+
value=ir.Expr(action_call=action_call, span=_make_span(test)),
|
|
1601
|
+
)
|
|
1602
|
+
)
|
|
1603
|
+
cond_expr = ir.Expr(variable=ir.Variable(name=cond_var), span=_make_span(test))
|
|
1604
|
+
return ([assign_stmt], cond_expr)
|
|
1358
1605
|
|
|
1359
|
-
|
|
1360
|
-
|
|
1361
|
-
|
|
1362
|
-
|
|
1363
|
-
|
|
1364
|
-
stmt = ir.Statement(span=_make_span(node))
|
|
1606
|
+
def visit_body(nodes: list[ast.stmt]) -> List[ir.Statement]:
|
|
1607
|
+
stmts: List[ir.Statement] = []
|
|
1608
|
+
for body_node in nodes:
|
|
1609
|
+
stmts.extend(self._visit_statement(body_node))
|
|
1610
|
+
return stmts
|
|
1365
1611
|
|
|
1366
|
-
#
|
|
1367
|
-
|
|
1368
|
-
|
|
1369
|
-
|
|
1612
|
+
# Collect if/elif branches as (test_expr, body_nodes)
|
|
1613
|
+
branches: list[tuple[ast.expr, list[ast.stmt], ast.AST]] = [(node.test, node.body, node)]
|
|
1614
|
+
current = node
|
|
1615
|
+
while current.orelse and len(current.orelse) == 1 and isinstance(current.orelse[0], ast.If):
|
|
1616
|
+
elif_node = current.orelse[0]
|
|
1617
|
+
branches.append((elif_node.test, elif_node.body, elif_node))
|
|
1618
|
+
current = elif_node
|
|
1619
|
+
|
|
1620
|
+
else_nodes = current.orelse
|
|
1621
|
+
|
|
1622
|
+
normalized: list[
|
|
1623
|
+
tuple[List[ir.Statement], Optional[ir.Expr], List[ir.Statement], ast.AST]
|
|
1624
|
+
] = []
|
|
1625
|
+
for test_expr, body_nodes, span_node in branches:
|
|
1626
|
+
prefix, cond = normalize_condition(test_expr)
|
|
1627
|
+
normalized.append((prefix, cond, visit_body(body_nodes), span_node))
|
|
1628
|
+
|
|
1629
|
+
else_body = visit_body(else_nodes) if else_nodes else []
|
|
1630
|
+
|
|
1631
|
+
# If any non-first branch needs normalization, preserve Python semantics by nesting.
|
|
1632
|
+
requires_nested = any(prefix for prefix, _, _, _ in normalized[1:])
|
|
1633
|
+
|
|
1634
|
+
def build_conditional_stmt(
|
|
1635
|
+
condition: ir.Expr,
|
|
1636
|
+
then_body: List[ir.Statement],
|
|
1637
|
+
else_body_statements: List[ir.Statement],
|
|
1638
|
+
span_node: ast.AST,
|
|
1639
|
+
) -> ir.Statement:
|
|
1640
|
+
conditional_stmt = ir.Statement(span=_make_span(span_node))
|
|
1641
|
+
if_branch = ir.IfBranch(
|
|
1642
|
+
condition=condition,
|
|
1643
|
+
block_body=ir.Block(statements=then_body, span=_make_span(span_node)),
|
|
1644
|
+
span=_make_span(span_node),
|
|
1645
|
+
)
|
|
1646
|
+
conditional = ir.Conditional(if_branch=if_branch)
|
|
1647
|
+
if else_body_statements:
|
|
1648
|
+
else_branch = ir.ElseBranch(
|
|
1649
|
+
block_body=ir.Block(
|
|
1650
|
+
statements=else_body_statements,
|
|
1651
|
+
span=_make_span(span_node),
|
|
1652
|
+
),
|
|
1653
|
+
span=_make_span(span_node),
|
|
1654
|
+
)
|
|
1655
|
+
conditional.else_branch.CopyFrom(else_branch)
|
|
1656
|
+
conditional_stmt.conditional.CopyFrom(conditional)
|
|
1657
|
+
return conditional_stmt
|
|
1370
1658
|
|
|
1371
|
-
|
|
1372
|
-
|
|
1373
|
-
|
|
1374
|
-
|
|
1659
|
+
if requires_nested:
|
|
1660
|
+
nested_else: List[ir.Statement] = else_body
|
|
1661
|
+
for prefix, cond, then_body, span_node in reversed(normalized):
|
|
1662
|
+
if cond is None:
|
|
1663
|
+
continue
|
|
1664
|
+
nested_if_stmt = build_conditional_stmt(
|
|
1665
|
+
condition=cond,
|
|
1666
|
+
then_body=then_body,
|
|
1667
|
+
else_body_statements=nested_else,
|
|
1668
|
+
span_node=span_node,
|
|
1669
|
+
)
|
|
1670
|
+
nested_else = [*prefix, nested_if_stmt]
|
|
1671
|
+
return nested_else
|
|
1375
1672
|
|
|
1376
|
-
#
|
|
1377
|
-
|
|
1378
|
-
|
|
1379
|
-
|
|
1380
|
-
body_stmts, "if_then", node, modified_vars=modified_vars
|
|
1381
|
-
)
|
|
1673
|
+
# Flat conditional with elif/else (original behavior), plus optional prefix for the if guard.
|
|
1674
|
+
if_prefix, if_condition, if_body, if_span_node = normalized[0]
|
|
1675
|
+
if if_condition is None:
|
|
1676
|
+
return []
|
|
1382
1677
|
|
|
1678
|
+
conditional_stmt = ir.Statement(span=_make_span(if_span_node))
|
|
1383
1679
|
if_branch = ir.IfBranch(
|
|
1384
|
-
condition=
|
|
1385
|
-
|
|
1386
|
-
span=_make_span(
|
|
1680
|
+
condition=if_condition,
|
|
1681
|
+
block_body=ir.Block(statements=if_body, span=_make_span(if_span_node)),
|
|
1682
|
+
span=_make_span(if_span_node),
|
|
1387
1683
|
)
|
|
1388
|
-
|
|
1389
1684
|
conditional = ir.Conditional(if_branch=if_branch)
|
|
1390
1685
|
|
|
1391
|
-
|
|
1392
|
-
|
|
1393
|
-
|
|
1394
|
-
|
|
1395
|
-
|
|
1396
|
-
|
|
1397
|
-
|
|
1398
|
-
|
|
1399
|
-
|
|
1400
|
-
for body_node in elif_node.body:
|
|
1401
|
-
stmts = self._visit_statement(body_node)
|
|
1402
|
-
elif_body.extend(stmts)
|
|
1403
|
-
|
|
1404
|
-
# ALWAYS wrap elif body for variable isolation
|
|
1405
|
-
in_scope_vars = self._collect_assigned_vars(elif_body)
|
|
1406
|
-
modified_vars = self._detect_accumulator_targets(elif_body, in_scope_vars)
|
|
1407
|
-
elif_body = self._wrap_body_as_function(
|
|
1408
|
-
elif_body, "if_elif", elif_node, modified_vars=modified_vars
|
|
1409
|
-
)
|
|
1410
|
-
|
|
1411
|
-
elif_branch = ir.ElifBranch(
|
|
1412
|
-
condition=elif_condition,
|
|
1413
|
-
body=self._stmts_to_single_call_body(elif_body, _make_span(elif_node)),
|
|
1414
|
-
span=_make_span(elif_node),
|
|
1415
|
-
)
|
|
1416
|
-
conditional.elif_branches.append(elif_branch)
|
|
1417
|
-
current = elif_node
|
|
1418
|
-
else:
|
|
1419
|
-
# else branch
|
|
1420
|
-
else_body: List[ir.Statement] = []
|
|
1421
|
-
for else_node in current.orelse:
|
|
1422
|
-
stmts = self._visit_statement(else_node)
|
|
1423
|
-
else_body.extend(stmts)
|
|
1424
|
-
|
|
1425
|
-
# ALWAYS wrap else body for variable isolation
|
|
1426
|
-
in_scope_vars = self._collect_assigned_vars(else_body)
|
|
1427
|
-
modified_vars = self._detect_accumulator_targets(else_body, in_scope_vars)
|
|
1428
|
-
else_body = self._wrap_body_as_function(
|
|
1429
|
-
else_body, "if_else", current.orelse[0], modified_vars=modified_vars
|
|
1430
|
-
)
|
|
1686
|
+
for _, elif_condition, elif_body, elif_span_node in normalized[1:]:
|
|
1687
|
+
if elif_condition is None:
|
|
1688
|
+
continue
|
|
1689
|
+
elif_branch = ir.ElifBranch(
|
|
1690
|
+
condition=elif_condition,
|
|
1691
|
+
block_body=ir.Block(statements=elif_body, span=_make_span(elif_span_node)),
|
|
1692
|
+
span=_make_span(elif_span_node),
|
|
1693
|
+
)
|
|
1694
|
+
conditional.elif_branches.append(elif_branch)
|
|
1431
1695
|
|
|
1432
|
-
|
|
1433
|
-
|
|
1434
|
-
|
|
1435
|
-
|
|
1436
|
-
|
|
1437
|
-
|
|
1438
|
-
conditional.else_branch.CopyFrom(else_branch)
|
|
1439
|
-
break
|
|
1696
|
+
if else_body:
|
|
1697
|
+
else_branch = ir.ElseBranch(
|
|
1698
|
+
block_body=ir.Block(statements=else_body, span=_make_span(if_span_node)),
|
|
1699
|
+
span=_make_span(if_span_node),
|
|
1700
|
+
)
|
|
1701
|
+
conditional.else_branch.CopyFrom(else_branch)
|
|
1440
1702
|
|
|
1441
|
-
|
|
1442
|
-
return
|
|
1703
|
+
conditional_stmt.conditional.CopyFrom(conditional)
|
|
1704
|
+
return [*if_prefix, conditional_stmt]
|
|
1443
1705
|
|
|
1444
1706
|
def _collect_assigned_vars(self, stmts: List[ir.Statement]) -> set:
|
|
1445
1707
|
"""Collect all variable names assigned in a list of statements."""
|
|
@@ -1463,48 +1725,48 @@ class IRBuilder(ast.NodeVisitor):
|
|
|
1463
1725
|
|
|
1464
1726
|
if stmt.HasField("conditional"):
|
|
1465
1727
|
cond = stmt.conditional
|
|
1466
|
-
if cond.HasField("if_branch") and cond.if_branch.HasField("
|
|
1728
|
+
if cond.HasField("if_branch") and cond.if_branch.HasField("block_body"):
|
|
1467
1729
|
for target in self._collect_assigned_vars_in_order(
|
|
1468
|
-
list(cond.if_branch.
|
|
1730
|
+
list(cond.if_branch.block_body.statements)
|
|
1469
1731
|
):
|
|
1470
1732
|
if target not in seen:
|
|
1471
1733
|
seen.add(target)
|
|
1472
1734
|
assigned.append(target)
|
|
1473
1735
|
for elif_branch in cond.elif_branches:
|
|
1474
|
-
if elif_branch.HasField("
|
|
1736
|
+
if elif_branch.HasField("block_body"):
|
|
1475
1737
|
for target in self._collect_assigned_vars_in_order(
|
|
1476
|
-
list(elif_branch.
|
|
1738
|
+
list(elif_branch.block_body.statements)
|
|
1477
1739
|
):
|
|
1478
1740
|
if target not in seen:
|
|
1479
1741
|
seen.add(target)
|
|
1480
1742
|
assigned.append(target)
|
|
1481
|
-
if cond.HasField("else_branch") and cond.else_branch.HasField("
|
|
1743
|
+
if cond.HasField("else_branch") and cond.else_branch.HasField("block_body"):
|
|
1482
1744
|
for target in self._collect_assigned_vars_in_order(
|
|
1483
|
-
list(cond.else_branch.
|
|
1745
|
+
list(cond.else_branch.block_body.statements)
|
|
1484
1746
|
):
|
|
1485
1747
|
if target not in seen:
|
|
1486
1748
|
seen.add(target)
|
|
1487
1749
|
assigned.append(target)
|
|
1488
1750
|
|
|
1489
|
-
if stmt.HasField("for_loop") and stmt.for_loop.HasField("
|
|
1751
|
+
if stmt.HasField("for_loop") and stmt.for_loop.HasField("block_body"):
|
|
1490
1752
|
for target in self._collect_assigned_vars_in_order(
|
|
1491
|
-
list(stmt.for_loop.
|
|
1753
|
+
list(stmt.for_loop.block_body.statements)
|
|
1492
1754
|
):
|
|
1493
1755
|
if target not in seen:
|
|
1494
1756
|
seen.add(target)
|
|
1495
1757
|
assigned.append(target)
|
|
1496
1758
|
|
|
1497
1759
|
if stmt.HasField("try_except"):
|
|
1498
|
-
|
|
1499
|
-
if
|
|
1500
|
-
for target in self._collect_assigned_vars_in_order(list(
|
|
1760
|
+
try_block = stmt.try_except.try_block
|
|
1761
|
+
if try_block.HasField("span"):
|
|
1762
|
+
for target in self._collect_assigned_vars_in_order(list(try_block.statements)):
|
|
1501
1763
|
if target not in seen:
|
|
1502
1764
|
seen.add(target)
|
|
1503
1765
|
assigned.append(target)
|
|
1504
1766
|
for handler in stmt.try_except.handlers:
|
|
1505
|
-
if handler.HasField("
|
|
1767
|
+
if handler.HasField("block_body"):
|
|
1506
1768
|
for target in self._collect_assigned_vars_in_order(
|
|
1507
|
-
list(handler.
|
|
1769
|
+
list(handler.block_body.statements)
|
|
1508
1770
|
):
|
|
1509
1771
|
if target not in seen:
|
|
1510
1772
|
seen.add(target)
|
|
@@ -1512,32 +1774,8 @@ class IRBuilder(ast.NodeVisitor):
|
|
|
1512
1774
|
|
|
1513
1775
|
return assigned
|
|
1514
1776
|
|
|
1515
|
-
def
|
|
1516
|
-
|
|
1517
|
-
seen: set[str] = set()
|
|
1518
|
-
|
|
1519
|
-
if body.HasField("call"):
|
|
1520
|
-
call = body.call
|
|
1521
|
-
if call.HasField("action"):
|
|
1522
|
-
for kwarg in call.action.kwargs:
|
|
1523
|
-
for var in self._collect_variables_from_expr(kwarg.value):
|
|
1524
|
-
if var not in seen:
|
|
1525
|
-
seen.add(var)
|
|
1526
|
-
vars_found.append(var)
|
|
1527
|
-
elif call.HasField("function"):
|
|
1528
|
-
for kwarg in call.function.kwargs:
|
|
1529
|
-
for var in self._collect_variables_from_expr(kwarg.value):
|
|
1530
|
-
if var not in seen:
|
|
1531
|
-
seen.add(var)
|
|
1532
|
-
vars_found.append(var)
|
|
1533
|
-
|
|
1534
|
-
for stmt in body.statements:
|
|
1535
|
-
for var in self._collect_variables_from_statements([stmt]):
|
|
1536
|
-
if var not in seen:
|
|
1537
|
-
seen.add(var)
|
|
1538
|
-
vars_found.append(var)
|
|
1539
|
-
|
|
1540
|
-
return vars_found
|
|
1777
|
+
def _collect_variables_from_block(self, block: ir.Block) -> list[str]:
|
|
1778
|
+
return self._collect_variables_from_statements(list(block.statements))
|
|
1541
1779
|
|
|
1542
1780
|
def _collect_variables_from_statements(self, stmts: List[ir.Statement]) -> list[str]:
|
|
1543
1781
|
"""Collect variable references from statements in encounter order."""
|
|
@@ -1578,10 +1816,8 @@ class IRBuilder(ast.NodeVisitor):
|
|
|
1578
1816
|
if var not in seen:
|
|
1579
1817
|
seen.add(var)
|
|
1580
1818
|
vars_found.append(var)
|
|
1581
|
-
if cond.if_branch.HasField("
|
|
1582
|
-
for var in self.
|
|
1583
|
-
cond.if_branch.body
|
|
1584
|
-
):
|
|
1819
|
+
if cond.if_branch.HasField("block_body"):
|
|
1820
|
+
for var in self._collect_variables_from_block(cond.if_branch.block_body):
|
|
1585
1821
|
if var not in seen:
|
|
1586
1822
|
seen.add(var)
|
|
1587
1823
|
vars_found.append(var)
|
|
@@ -1591,13 +1827,13 @@ class IRBuilder(ast.NodeVisitor):
|
|
|
1591
1827
|
if var not in seen:
|
|
1592
1828
|
seen.add(var)
|
|
1593
1829
|
vars_found.append(var)
|
|
1594
|
-
if elif_branch.HasField("
|
|
1595
|
-
for var in self.
|
|
1830
|
+
if elif_branch.HasField("block_body"):
|
|
1831
|
+
for var in self._collect_variables_from_block(elif_branch.block_body):
|
|
1596
1832
|
if var not in seen:
|
|
1597
1833
|
seen.add(var)
|
|
1598
1834
|
vars_found.append(var)
|
|
1599
|
-
if cond.HasField("else_branch") and cond.else_branch.HasField("
|
|
1600
|
-
for var in self.
|
|
1835
|
+
if cond.HasField("else_branch") and cond.else_branch.HasField("block_body"):
|
|
1836
|
+
for var in self._collect_variables_from_block(cond.else_branch.block_body):
|
|
1601
1837
|
if var not in seen:
|
|
1602
1838
|
seen.add(var)
|
|
1603
1839
|
vars_found.append(var)
|
|
@@ -1609,22 +1845,22 @@ class IRBuilder(ast.NodeVisitor):
|
|
|
1609
1845
|
if var not in seen:
|
|
1610
1846
|
seen.add(var)
|
|
1611
1847
|
vars_found.append(var)
|
|
1612
|
-
if fl.HasField("
|
|
1613
|
-
for var in self.
|
|
1848
|
+
if fl.HasField("block_body"):
|
|
1849
|
+
for var in self._collect_variables_from_block(fl.block_body):
|
|
1614
1850
|
if var not in seen:
|
|
1615
1851
|
seen.add(var)
|
|
1616
1852
|
vars_found.append(var)
|
|
1617
1853
|
|
|
1618
1854
|
if stmt.HasField("try_except"):
|
|
1619
1855
|
te = stmt.try_except
|
|
1620
|
-
if te.HasField("
|
|
1621
|
-
for var in self.
|
|
1856
|
+
if te.HasField("try_block"):
|
|
1857
|
+
for var in self._collect_variables_from_block(te.try_block):
|
|
1622
1858
|
if var not in seen:
|
|
1623
1859
|
seen.add(var)
|
|
1624
1860
|
vars_found.append(var)
|
|
1625
1861
|
for handler in te.handlers:
|
|
1626
|
-
if handler.HasField("
|
|
1627
|
-
for var in self.
|
|
1862
|
+
if handler.HasField("block_body"):
|
|
1863
|
+
for var in self._collect_variables_from_block(handler.block_body):
|
|
1628
1864
|
if var not in seen:
|
|
1629
1865
|
seen.add(var)
|
|
1630
1866
|
vars_found.append(var)
|
|
@@ -1661,66 +1897,13 @@ class IRBuilder(ast.NodeVisitor):
|
|
|
1661
1897
|
return vars_found
|
|
1662
1898
|
|
|
1663
1899
|
def _visit_try(self, node: ast.Try) -> List[ir.Statement]:
|
|
1664
|
-
"""Convert try/except to IR with
|
|
1665
|
-
|
|
1666
|
-
If the try body has multiple action calls, we wrap the entire body
|
|
1667
|
-
into a synthetic function, preserving exact semantics.
|
|
1668
|
-
|
|
1669
|
-
Python:
|
|
1670
|
-
try:
|
|
1671
|
-
a = await setup_action()
|
|
1672
|
-
b = await risky_action(a)
|
|
1673
|
-
return f"success:{b}"
|
|
1674
|
-
except SomeError:
|
|
1675
|
-
...
|
|
1676
|
-
|
|
1677
|
-
Becomes IR equivalent of:
|
|
1678
|
-
fn __try_body_1__():
|
|
1679
|
-
a = @setup_action()
|
|
1680
|
-
b = @risky_action(a=a)
|
|
1681
|
-
return f"success:{b}"
|
|
1682
|
-
|
|
1683
|
-
try:
|
|
1684
|
-
__try_body_1__()
|
|
1685
|
-
except SomeError:
|
|
1686
|
-
...
|
|
1687
|
-
"""
|
|
1900
|
+
"""Convert try/except to IR with full block bodies."""
|
|
1688
1901
|
# Build try body statements (recursively transforms nested structures)
|
|
1689
1902
|
try_body: List[ir.Statement] = []
|
|
1690
1903
|
for body_node in node.body:
|
|
1691
1904
|
stmts = self._visit_statement(body_node)
|
|
1692
1905
|
try_body.extend(stmts)
|
|
1693
1906
|
|
|
1694
|
-
# ALWAYS wrap try body for variable isolation
|
|
1695
|
-
assigned_vars_ordered = self._collect_assigned_vars_in_order(try_body)
|
|
1696
|
-
assigned_vars_set = set(assigned_vars_ordered)
|
|
1697
|
-
free_vars = [
|
|
1698
|
-
var
|
|
1699
|
-
for var in self._collect_variables_from_statements(try_body)
|
|
1700
|
-
if var not in assigned_vars_set
|
|
1701
|
-
]
|
|
1702
|
-
modified_vars = self._detect_accumulator_targets(try_body, assigned_vars_set)
|
|
1703
|
-
|
|
1704
|
-
# Inputs need free variables plus any accumulator-style mutations.
|
|
1705
|
-
try_inputs = []
|
|
1706
|
-
for var in free_vars + modified_vars:
|
|
1707
|
-
if var not in try_inputs:
|
|
1708
|
-
try_inputs.append(var)
|
|
1709
|
-
|
|
1710
|
-
# Outputs include all assigned variables plus accumulator targets.
|
|
1711
|
-
try_outputs: list[str] = []
|
|
1712
|
-
for var in assigned_vars_ordered + modified_vars:
|
|
1713
|
-
if var not in try_outputs:
|
|
1714
|
-
try_outputs.append(var)
|
|
1715
|
-
|
|
1716
|
-
try_body = self._wrap_body_as_function(
|
|
1717
|
-
try_body,
|
|
1718
|
-
"try_body",
|
|
1719
|
-
node,
|
|
1720
|
-
inputs=try_inputs,
|
|
1721
|
-
modified_vars=try_outputs,
|
|
1722
|
-
)
|
|
1723
|
-
|
|
1724
1907
|
# Build exception handlers (with wrapping if needed)
|
|
1725
1908
|
handlers: List[ir.ExceptHandler] = []
|
|
1726
1909
|
for handler in node.handlers:
|
|
@@ -1739,46 +1922,20 @@ class IRBuilder(ast.NodeVisitor):
|
|
|
1739
1922
|
stmts = self._visit_statement(handler_node)
|
|
1740
1923
|
handler_body.extend(stmts)
|
|
1741
1924
|
|
|
1742
|
-
# ALWAYS wrap handler body for variable isolation
|
|
1743
|
-
assigned_vars_ordered = self._collect_assigned_vars_in_order(handler_body)
|
|
1744
|
-
assigned_vars_set = set(assigned_vars_ordered)
|
|
1745
|
-
free_vars = [
|
|
1746
|
-
var
|
|
1747
|
-
for var in self._collect_variables_from_statements(handler_body)
|
|
1748
|
-
if var not in assigned_vars_set
|
|
1749
|
-
]
|
|
1750
|
-
modified_vars = self._detect_accumulator_targets(handler_body, assigned_vars_set)
|
|
1751
|
-
|
|
1752
|
-
handler_inputs: list[str] = []
|
|
1753
|
-
for var in free_vars + modified_vars:
|
|
1754
|
-
if var not in handler_inputs:
|
|
1755
|
-
handler_inputs.append(var)
|
|
1756
|
-
|
|
1757
|
-
handler_outputs: list[str] = []
|
|
1758
|
-
for var in assigned_vars_ordered + modified_vars:
|
|
1759
|
-
if var not in handler_outputs:
|
|
1760
|
-
handler_outputs.append(var)
|
|
1761
|
-
|
|
1762
|
-
handler_body = self._wrap_body_as_function(
|
|
1763
|
-
handler_body,
|
|
1764
|
-
"except_handler",
|
|
1765
|
-
node,
|
|
1766
|
-
inputs=handler_inputs,
|
|
1767
|
-
modified_vars=handler_outputs,
|
|
1768
|
-
)
|
|
1769
|
-
|
|
1770
1925
|
except_handler = ir.ExceptHandler(
|
|
1771
1926
|
exception_types=exception_types,
|
|
1772
|
-
|
|
1927
|
+
block_body=ir.Block(statements=handler_body, span=_make_span(handler)),
|
|
1773
1928
|
span=_make_span(handler),
|
|
1774
1929
|
)
|
|
1930
|
+
if handler.name:
|
|
1931
|
+
except_handler.exception_var = handler.name
|
|
1775
1932
|
handlers.append(except_handler)
|
|
1776
1933
|
|
|
1777
1934
|
# Build the try/except statement
|
|
1778
1935
|
try_stmt = ir.Statement(span=_make_span(node))
|
|
1779
1936
|
try_except = ir.TryExcept(
|
|
1780
|
-
try_body=self._stmts_to_single_call_body(try_body, _make_span(node)),
|
|
1781
1937
|
handlers=handlers,
|
|
1938
|
+
try_block=ir.Block(statements=try_body, span=_make_span(node)),
|
|
1782
1939
|
)
|
|
1783
1940
|
try_stmt.try_except.CopyFrom(try_except)
|
|
1784
1941
|
|
|
@@ -1806,56 +1963,6 @@ class IRBuilder(ast.NodeVisitor):
|
|
|
1806
1963
|
count += 1
|
|
1807
1964
|
return count
|
|
1808
1965
|
|
|
1809
|
-
def _stmts_to_single_call_body(
|
|
1810
|
-
self, stmts: List[ir.Statement], span: ir.Span
|
|
1811
|
-
) -> ir.SingleCallBody:
|
|
1812
|
-
"""Convert statements to SingleCallBody.
|
|
1813
|
-
|
|
1814
|
-
Can contain EITHER:
|
|
1815
|
-
1. A single action or function call (with optional target)
|
|
1816
|
-
2. Pure data statements (no calls)
|
|
1817
|
-
"""
|
|
1818
|
-
body = ir.SingleCallBody(span=span)
|
|
1819
|
-
|
|
1820
|
-
# Look for a single call in the statements
|
|
1821
|
-
for stmt in stmts:
|
|
1822
|
-
if stmt.HasField("action_call"):
|
|
1823
|
-
# ActionCall as a statement has no target (side-effect only)
|
|
1824
|
-
action = stmt.action_call
|
|
1825
|
-
call = ir.Call()
|
|
1826
|
-
call.action.CopyFrom(action)
|
|
1827
|
-
body.call.CopyFrom(call)
|
|
1828
|
-
return body
|
|
1829
|
-
elif stmt.HasField("assignment"):
|
|
1830
|
-
# Check if assignment value is an action call or function call
|
|
1831
|
-
if stmt.assignment.value.HasField("action_call"):
|
|
1832
|
-
action = stmt.assignment.value.action_call
|
|
1833
|
-
# Copy all targets for tuple unpacking support
|
|
1834
|
-
body.targets.extend(stmt.assignment.targets)
|
|
1835
|
-
call = ir.Call()
|
|
1836
|
-
call.action.CopyFrom(action)
|
|
1837
|
-
body.call.CopyFrom(call)
|
|
1838
|
-
return body
|
|
1839
|
-
elif stmt.assignment.value.HasField("function_call"):
|
|
1840
|
-
fn_call = stmt.assignment.value.function_call
|
|
1841
|
-
# Copy all targets for tuple unpacking support
|
|
1842
|
-
body.targets.extend(stmt.assignment.targets)
|
|
1843
|
-
call = ir.Call()
|
|
1844
|
-
call.function.CopyFrom(fn_call)
|
|
1845
|
-
body.call.CopyFrom(call)
|
|
1846
|
-
return body
|
|
1847
|
-
elif stmt.HasField("expr_stmt") and stmt.expr_stmt.expr.HasField("function_call"):
|
|
1848
|
-
fn_call = stmt.expr_stmt.expr.function_call
|
|
1849
|
-
call = ir.Call()
|
|
1850
|
-
call.function.CopyFrom(fn_call)
|
|
1851
|
-
body.call.CopyFrom(call)
|
|
1852
|
-
return body
|
|
1853
|
-
|
|
1854
|
-
# No call found - this is a pure data body
|
|
1855
|
-
# Add all statements as pure data
|
|
1856
|
-
body.statements.extend(stmts)
|
|
1857
|
-
return body
|
|
1858
|
-
|
|
1859
1966
|
def _wrap_body_as_function(
|
|
1860
1967
|
self,
|
|
1861
1968
|
body: List[ir.Statement],
|
|
@@ -1976,8 +2083,20 @@ class IRBuilder(ast.NodeVisitor):
|
|
|
1976
2083
|
|
|
1977
2084
|
return [assign_stmt, return_stmt]
|
|
1978
2085
|
|
|
2086
|
+
# Normalize return of function calls into assignment + return
|
|
2087
|
+
expr = self._expr_to_ir_with_model_coercion(node.value)
|
|
2088
|
+
if expr and expr.HasField("function_call"):
|
|
2089
|
+
tmp_var = self._ctx.next_implicit_fn_name(prefix="return_tmp")
|
|
2090
|
+
|
|
2091
|
+
assign_stmt = ir.Statement(span=_make_span(node))
|
|
2092
|
+
assign_stmt.assignment.CopyFrom(ir.Assignment(targets=[tmp_var], value=expr))
|
|
2093
|
+
|
|
2094
|
+
return_stmt = ir.Statement(span=_make_span(node))
|
|
2095
|
+
var_expr = ir.Expr(variable=ir.Variable(name=tmp_var), span=_make_span(node))
|
|
2096
|
+
return_stmt.return_stmt.CopyFrom(ir.ReturnStmt(value=var_expr))
|
|
2097
|
+
return [assign_stmt, return_stmt]
|
|
2098
|
+
|
|
1979
2099
|
# Regular return with expression (variable, literal, etc.)
|
|
1980
|
-
expr = _expr_to_ir(node.value)
|
|
1981
2100
|
if expr:
|
|
1982
2101
|
stmt = ir.Statement(span=_make_span(node))
|
|
1983
2102
|
return_stmt = ir.ReturnStmt(value=expr)
|
|
@@ -1989,7 +2108,19 @@ class IRBuilder(ast.NodeVisitor):
|
|
|
1989
2108
|
stmt.return_stmt.CopyFrom(ir.ReturnStmt())
|
|
1990
2109
|
return [stmt]
|
|
1991
2110
|
|
|
1992
|
-
def
|
|
2111
|
+
def _visit_break(self, node: ast.Break) -> List[ir.Statement]:
|
|
2112
|
+
"""Convert break statement to IR."""
|
|
2113
|
+
stmt = ir.Statement(span=_make_span(node))
|
|
2114
|
+
stmt.break_stmt.CopyFrom(ir.BreakStmt())
|
|
2115
|
+
return [stmt]
|
|
2116
|
+
|
|
2117
|
+
def _visit_continue(self, node: ast.Continue) -> List[ir.Statement]:
|
|
2118
|
+
"""Convert continue statement to IR."""
|
|
2119
|
+
stmt = ir.Statement(span=_make_span(node))
|
|
2120
|
+
stmt.continue_stmt.CopyFrom(ir.ContinueStmt())
|
|
2121
|
+
return [stmt]
|
|
2122
|
+
|
|
2123
|
+
def _visit_aug_assign(self, node: ast.AugAssign) -> List[ir.Statement]:
|
|
1993
2124
|
"""Convert augmented assignment (+=, -=, etc.) to IR."""
|
|
1994
2125
|
# For now, we can represent this as a regular assignment with binary op
|
|
1995
2126
|
# target op= value -> target = target op value
|
|
@@ -1999,8 +2130,33 @@ class IRBuilder(ast.NodeVisitor):
|
|
|
1999
2130
|
if isinstance(node.target, ast.Name):
|
|
2000
2131
|
targets.append(node.target.id)
|
|
2001
2132
|
|
|
2002
|
-
left =
|
|
2003
|
-
right =
|
|
2133
|
+
left = self._expr_to_ir_with_model_coercion(node.target)
|
|
2134
|
+
right = self._expr_to_ir_with_model_coercion(node.value)
|
|
2135
|
+
if right and right.HasField("function_call"):
|
|
2136
|
+
tmp_var = self._ctx.next_implicit_fn_name(prefix="aug_tmp")
|
|
2137
|
+
|
|
2138
|
+
assign_tmp = ir.Statement(span=_make_span(node))
|
|
2139
|
+
assign_tmp.assignment.CopyFrom(
|
|
2140
|
+
ir.Assignment(
|
|
2141
|
+
targets=[tmp_var],
|
|
2142
|
+
value=ir.Expr(function_call=right.function_call, span=_make_span(node)),
|
|
2143
|
+
)
|
|
2144
|
+
)
|
|
2145
|
+
|
|
2146
|
+
if left:
|
|
2147
|
+
op = _bin_op_to_ir(node.op)
|
|
2148
|
+
if op:
|
|
2149
|
+
binary = ir.BinaryOp(
|
|
2150
|
+
left=left,
|
|
2151
|
+
op=op,
|
|
2152
|
+
right=ir.Expr(variable=ir.Variable(name=tmp_var)),
|
|
2153
|
+
)
|
|
2154
|
+
value = ir.Expr(binary_op=binary)
|
|
2155
|
+
assign = ir.Assignment(targets=targets, value=value)
|
|
2156
|
+
stmt.assignment.CopyFrom(assign)
|
|
2157
|
+
return [assign_tmp, stmt]
|
|
2158
|
+
return [assign_tmp]
|
|
2159
|
+
|
|
2004
2160
|
if left and right:
|
|
2005
2161
|
op = _bin_op_to_ir(node.op)
|
|
2006
2162
|
if op:
|
|
@@ -2008,9 +2164,38 @@ class IRBuilder(ast.NodeVisitor):
|
|
|
2008
2164
|
value = ir.Expr(binary_op=binary)
|
|
2009
2165
|
assign = ir.Assignment(targets=targets, value=value)
|
|
2010
2166
|
stmt.assignment.CopyFrom(assign)
|
|
2011
|
-
return stmt
|
|
2167
|
+
return [stmt]
|
|
2012
2168
|
|
|
2013
|
-
return
|
|
2169
|
+
return []
|
|
2170
|
+
|
|
2171
|
+
def _collect_function_inputs(self, node: ast.FunctionDef | ast.AsyncFunctionDef) -> List[str]:
|
|
2172
|
+
"""Collect workflow inputs from function parameters, including kw-only args."""
|
|
2173
|
+
args: List[str] = []
|
|
2174
|
+
seen: set[str] = set()
|
|
2175
|
+
|
|
2176
|
+
ordered_args = list(node.args.posonlyargs) + list(node.args.args)
|
|
2177
|
+
if ordered_args and ordered_args[0].arg == "self":
|
|
2178
|
+
ordered_args = ordered_args[1:]
|
|
2179
|
+
|
|
2180
|
+
for arg in ordered_args:
|
|
2181
|
+
if arg.arg not in seen:
|
|
2182
|
+
args.append(arg.arg)
|
|
2183
|
+
seen.add(arg.arg)
|
|
2184
|
+
|
|
2185
|
+
if node.args.vararg and node.args.vararg.arg not in seen:
|
|
2186
|
+
args.append(node.args.vararg.arg)
|
|
2187
|
+
seen.add(node.args.vararg.arg)
|
|
2188
|
+
|
|
2189
|
+
for arg in node.args.kwonlyargs:
|
|
2190
|
+
if arg.arg not in seen:
|
|
2191
|
+
args.append(arg.arg)
|
|
2192
|
+
seen.add(arg.arg)
|
|
2193
|
+
|
|
2194
|
+
if node.args.kwarg and node.args.kwarg.arg not in seen:
|
|
2195
|
+
args.append(node.args.kwarg.arg)
|
|
2196
|
+
seen.add(node.args.kwarg.arg)
|
|
2197
|
+
|
|
2198
|
+
return args
|
|
2014
2199
|
|
|
2015
2200
|
def _check_constructor_in_return(self, node: ast.expr) -> None:
|
|
2016
2201
|
"""Check for constructor calls in return statements.
|
|
@@ -2155,7 +2340,7 @@ class IRBuilder(ast.NodeVisitor):
|
|
|
2155
2340
|
key_literal.string_value = kw.arg
|
|
2156
2341
|
key_expr.literal.CopyFrom(key_literal)
|
|
2157
2342
|
|
|
2158
|
-
value_expr =
|
|
2343
|
+
value_expr = self._expr_to_ir_with_model_coercion(kw.value)
|
|
2159
2344
|
if value_expr is None:
|
|
2160
2345
|
# If we can't convert the value, we need to raise an error
|
|
2161
2346
|
line = getattr(node, "lineno", None)
|
|
@@ -2192,7 +2377,7 @@ class IRBuilder(ast.NodeVisitor):
|
|
|
2192
2377
|
key_literal.string_value = field_name
|
|
2193
2378
|
key_expr.literal.CopyFrom(key_literal)
|
|
2194
2379
|
|
|
2195
|
-
value_expr =
|
|
2380
|
+
value_expr = self._expr_to_ir_with_model_coercion(arg)
|
|
2196
2381
|
if value_expr is None:
|
|
2197
2382
|
line = getattr(node, "lineno", None)
|
|
2198
2383
|
col = getattr(node, "col_offset", None)
|
|
@@ -2410,7 +2595,19 @@ class IRBuilder(ast.NodeVisitor):
|
|
|
2410
2595
|
- RetryPolicy(attempts=3)
|
|
2411
2596
|
- RetryPolicy(attempts=3, exception_types=["ValueError"])
|
|
2412
2597
|
- RetryPolicy(attempts=3, backoff_seconds=5)
|
|
2598
|
+
- self.retry_policy (instance attribute reference)
|
|
2413
2599
|
"""
|
|
2600
|
+
# Handle self.attr pattern - look up in instance attrs
|
|
2601
|
+
if (
|
|
2602
|
+
isinstance(node, ast.Attribute)
|
|
2603
|
+
and isinstance(node.value, ast.Name)
|
|
2604
|
+
and node.value.id == "self"
|
|
2605
|
+
):
|
|
2606
|
+
attr_name = node.attr
|
|
2607
|
+
if attr_name in self._instance_attrs:
|
|
2608
|
+
return self._parse_retry_policy(self._instance_attrs[attr_name])
|
|
2609
|
+
return None
|
|
2610
|
+
|
|
2414
2611
|
if not isinstance(node, ast.Call):
|
|
2415
2612
|
return None
|
|
2416
2613
|
|
|
@@ -2428,7 +2625,9 @@ class IRBuilder(ast.NodeVisitor):
|
|
|
2428
2625
|
|
|
2429
2626
|
for kw in node.keywords:
|
|
2430
2627
|
if kw.arg == "attempts" and isinstance(kw.value, ast.Constant):
|
|
2431
|
-
|
|
2628
|
+
# attempts means total executions, max_retries means retries after first attempt
|
|
2629
|
+
# So attempts=1 -> max_retries=0 (no retries), attempts=3 -> max_retries=2
|
|
2630
|
+
policy.max_retries = kw.value.value - 1
|
|
2432
2631
|
elif kw.arg == "exception_types" and isinstance(kw.value, ast.List):
|
|
2433
2632
|
for elt in kw.value.elts:
|
|
2434
2633
|
if isinstance(elt, ast.Constant) and isinstance(elt.value, str):
|
|
@@ -2446,7 +2645,19 @@ class IRBuilder(ast.NodeVisitor):
|
|
|
2446
2645
|
- timeout=30.5 (float seconds)
|
|
2447
2646
|
- timeout=timedelta(seconds=30)
|
|
2448
2647
|
- timeout=timedelta(minutes=2)
|
|
2648
|
+
- self.timeout (instance attribute reference)
|
|
2449
2649
|
"""
|
|
2650
|
+
# Handle self.attr pattern - look up in instance attrs
|
|
2651
|
+
if (
|
|
2652
|
+
isinstance(node, ast.Attribute)
|
|
2653
|
+
and isinstance(node.value, ast.Name)
|
|
2654
|
+
and node.value.id == "self"
|
|
2655
|
+
):
|
|
2656
|
+
attr_name = node.attr
|
|
2657
|
+
if attr_name in self._instance_attrs:
|
|
2658
|
+
return self._parse_timeout_policy(self._instance_attrs[attr_name])
|
|
2659
|
+
return None
|
|
2660
|
+
|
|
2450
2661
|
policy = ir.TimeoutPolicy()
|
|
2451
2662
|
|
|
2452
2663
|
# Direct numeric value (seconds)
|
|
@@ -2511,14 +2722,14 @@ class IRBuilder(ast.NodeVisitor):
|
|
|
2511
2722
|
# Extract duration argument (positional or keyword)
|
|
2512
2723
|
if node.args:
|
|
2513
2724
|
# asyncio.sleep(1) - positional
|
|
2514
|
-
expr =
|
|
2725
|
+
expr = self._expr_to_ir_with_model_coercion(node.args[0])
|
|
2515
2726
|
if expr:
|
|
2516
2727
|
action_call.kwargs.append(ir.Kwarg(name="duration", value=expr))
|
|
2517
2728
|
elif node.keywords:
|
|
2518
2729
|
# asyncio.sleep(seconds=1) - keyword (less common)
|
|
2519
2730
|
for kw in node.keywords:
|
|
2520
2731
|
if kw.arg in ("seconds", "delay", "duration"):
|
|
2521
|
-
expr =
|
|
2732
|
+
expr = self._expr_to_ir_with_model_coercion(kw.value)
|
|
2522
2733
|
if expr:
|
|
2523
2734
|
action_call.kwargs.append(ir.Kwarg(name="duration", value=expr))
|
|
2524
2735
|
break
|
|
@@ -2606,9 +2817,10 @@ class IRBuilder(ast.NodeVisitor):
|
|
|
2606
2817
|
|
|
2607
2818
|
Handles patterns like:
|
|
2608
2819
|
[action(x=item) for item in collection]
|
|
2820
|
+
[self.run_action(action(x=item), retry=..., timeout=...) for item in collection]
|
|
2609
2821
|
|
|
2610
2822
|
The comprehension must have exactly one generator with no conditions,
|
|
2611
|
-
and the element must be an action call.
|
|
2823
|
+
and the element must be an action call (optionally wrapped in run_action).
|
|
2612
2824
|
|
|
2613
2825
|
Args:
|
|
2614
2826
|
listcomp: The ListComp AST node
|
|
@@ -2653,7 +2865,7 @@ class IRBuilder(ast.NodeVisitor):
|
|
|
2653
2865
|
loop_var = gen.target.id
|
|
2654
2866
|
|
|
2655
2867
|
# Get the collection expression
|
|
2656
|
-
collection_expr =
|
|
2868
|
+
collection_expr = self._expr_to_ir_with_model_coercion(gen.iter)
|
|
2657
2869
|
if not collection_expr:
|
|
2658
2870
|
line = getattr(listcomp, "lineno", None)
|
|
2659
2871
|
col = getattr(listcomp, "col_offset", None)
|
|
@@ -2664,7 +2876,7 @@ class IRBuilder(ast.NodeVisitor):
|
|
|
2664
2876
|
col=col,
|
|
2665
2877
|
)
|
|
2666
2878
|
|
|
2667
|
-
# The element must be
|
|
2879
|
+
# The element must be a call (either action call or run_action wrapper)
|
|
2668
2880
|
if not isinstance(listcomp.elt, ast.Call):
|
|
2669
2881
|
line = getattr(listcomp, "lineno", None)
|
|
2670
2882
|
col = getattr(listcomp, "col_offset", None)
|
|
@@ -2675,13 +2887,27 @@ class IRBuilder(ast.NodeVisitor):
|
|
|
2675
2887
|
col=col,
|
|
2676
2888
|
)
|
|
2677
2889
|
|
|
2678
|
-
|
|
2890
|
+
# Check for self.run_action(...) wrapper pattern
|
|
2891
|
+
action_call: Optional[ir.ActionCall] = None
|
|
2892
|
+
if self._is_run_action_call(listcomp.elt):
|
|
2893
|
+
# Extract the inner action call from run_action's first argument
|
|
2894
|
+
if listcomp.elt.args:
|
|
2895
|
+
inner_call = listcomp.elt.args[0]
|
|
2896
|
+
if isinstance(inner_call, ast.Call):
|
|
2897
|
+
action_call = self._extract_action_call_from_call(inner_call)
|
|
2898
|
+
if action_call:
|
|
2899
|
+
# Extract policies (retry, timeout) from run_action kwargs
|
|
2900
|
+
self._extract_policies_from_run_action(listcomp.elt, action_call)
|
|
2901
|
+
else:
|
|
2902
|
+
# Direct action call
|
|
2903
|
+
action_call = self._extract_action_call_from_call(listcomp.elt)
|
|
2904
|
+
|
|
2679
2905
|
if not action_call:
|
|
2680
2906
|
line = getattr(listcomp, "lineno", None)
|
|
2681
2907
|
col = getattr(listcomp, "col_offset", None)
|
|
2682
2908
|
raise UnsupportedPatternError(
|
|
2683
2909
|
"Spread pattern element must be an @action call",
|
|
2684
|
-
"Ensure the function is decorated with @action",
|
|
2910
|
+
"Ensure the function is decorated with @action, or use self.run_action(action(...), ...)",
|
|
2685
2911
|
line=line,
|
|
2686
2912
|
col=col,
|
|
2687
2913
|
)
|
|
@@ -2725,20 +2951,26 @@ class IRBuilder(ast.NodeVisitor):
|
|
|
2725
2951
|
return None
|
|
2726
2952
|
|
|
2727
2953
|
fn_call = ir.FunctionCall(name=func_name)
|
|
2954
|
+
global_function = _global_function_for_call(func_name, node)
|
|
2955
|
+
if global_function is not None:
|
|
2956
|
+
fn_call.global_function = global_function
|
|
2728
2957
|
|
|
2729
2958
|
# Add positional args
|
|
2730
2959
|
for arg in node.args:
|
|
2731
|
-
expr =
|
|
2960
|
+
expr = self._expr_to_ir_with_model_coercion(arg)
|
|
2732
2961
|
if expr:
|
|
2733
2962
|
fn_call.args.append(expr)
|
|
2734
2963
|
|
|
2735
2964
|
# Add keyword args
|
|
2736
2965
|
for kw in node.keywords:
|
|
2737
2966
|
if kw.arg:
|
|
2738
|
-
expr =
|
|
2967
|
+
expr = self._expr_to_ir_with_model_coercion(kw.value)
|
|
2739
2968
|
if expr:
|
|
2740
2969
|
fn_call.kwargs.append(ir.Kwarg(name=kw.arg, value=expr))
|
|
2741
2970
|
|
|
2971
|
+
# Fill in missing kwargs with default values from function signature
|
|
2972
|
+
self._fill_default_kwargs_for_expr(fn_call)
|
|
2973
|
+
|
|
2742
2974
|
return fn_call
|
|
2743
2975
|
|
|
2744
2976
|
def _get_func_name(self, node: ast.expr) -> Optional[str]:
|
|
@@ -2754,27 +2986,147 @@ class IRBuilder(ast.NodeVisitor):
|
|
|
2754
2986
|
current = current.value
|
|
2755
2987
|
if isinstance(current, ast.Name):
|
|
2756
2988
|
parts.append(current.id)
|
|
2757
|
-
|
|
2989
|
+
name = ".".join(reversed(parts))
|
|
2990
|
+
if name.startswith("self."):
|
|
2991
|
+
return name[5:]
|
|
2992
|
+
return name
|
|
2758
2993
|
return None
|
|
2759
2994
|
|
|
2760
|
-
def
|
|
2761
|
-
|
|
2995
|
+
def _convert_model_constructor_if_needed(self, node: ast.Call) -> Optional[ir.Expr]:
|
|
2996
|
+
model_name = self._is_model_constructor(node)
|
|
2997
|
+
if model_name:
|
|
2998
|
+
return self._convert_model_constructor_to_dict(node, model_name)
|
|
2999
|
+
return None
|
|
2762
3000
|
|
|
2763
|
-
|
|
2764
|
-
|
|
3001
|
+
def _resolve_enum_attribute(self, node: ast.Attribute) -> Optional[ir.Expr]:
|
|
3002
|
+
value = _resolve_enum_attribute_value(node, self._module_globals)
|
|
3003
|
+
if value is None:
|
|
3004
|
+
return None
|
|
3005
|
+
literal = _constant_to_literal(value)
|
|
3006
|
+
if literal is None:
|
|
3007
|
+
line = getattr(node, "lineno", None)
|
|
3008
|
+
col = getattr(node, "col_offset", None)
|
|
3009
|
+
raise UnsupportedPatternError(
|
|
3010
|
+
"Enum value must be a primitive literal",
|
|
3011
|
+
RECOMMENDATIONS["unsupported_literal"],
|
|
3012
|
+
line=line,
|
|
3013
|
+
col=col,
|
|
3014
|
+
)
|
|
3015
|
+
expr = ir.Expr(span=_make_span(node))
|
|
3016
|
+
expr.literal.CopyFrom(literal)
|
|
3017
|
+
return expr
|
|
3018
|
+
|
|
3019
|
+
def _is_exception_class(self, class_name: str) -> bool:
|
|
3020
|
+
"""Check if a class name refers to an exception class.
|
|
2765
3021
|
|
|
2766
|
-
|
|
2767
|
-
it is converted to a dict expression. Otherwise, falls back to the
|
|
2768
|
-
standard _expr_to_ir conversion.
|
|
3022
|
+
This checks both built-in exceptions and imported exception classes.
|
|
2769
3023
|
"""
|
|
2770
|
-
# Check
|
|
2771
|
-
|
|
2772
|
-
|
|
2773
|
-
if
|
|
2774
|
-
|
|
3024
|
+
# Check built-in exceptions first
|
|
3025
|
+
builtin_exc = (
|
|
3026
|
+
getattr(__builtins__, class_name, None)
|
|
3027
|
+
if isinstance(__builtins__, dict)
|
|
3028
|
+
else getattr(__builtins__, class_name, None)
|
|
3029
|
+
)
|
|
3030
|
+
if builtin_exc is None:
|
|
3031
|
+
# Try getting from builtins module directly
|
|
3032
|
+
import builtins
|
|
3033
|
+
|
|
3034
|
+
builtin_exc = getattr(builtins, class_name, None)
|
|
3035
|
+
if (
|
|
3036
|
+
builtin_exc is not None
|
|
3037
|
+
and isinstance(builtin_exc, type)
|
|
3038
|
+
and issubclass(builtin_exc, BaseException)
|
|
3039
|
+
):
|
|
3040
|
+
return True
|
|
3041
|
+
|
|
3042
|
+
# Check module globals for imported exception classes
|
|
3043
|
+
cls = self._module_globals.get(class_name)
|
|
3044
|
+
if cls is not None and isinstance(cls, type) and issubclass(cls, BaseException):
|
|
3045
|
+
return True
|
|
3046
|
+
|
|
3047
|
+
return False
|
|
3048
|
+
|
|
3049
|
+
def _expr_to_ir_with_model_coercion(self, node: ast.expr) -> Optional[ir.Expr]:
|
|
3050
|
+
"""Convert an AST expression to IR, converting model constructors to dicts."""
|
|
3051
|
+
result = _expr_to_ir(
|
|
3052
|
+
node,
|
|
3053
|
+
model_converter=self._convert_model_constructor_if_needed,
|
|
3054
|
+
enum_resolver=self._resolve_enum_attribute,
|
|
3055
|
+
exception_class_resolver=self._is_exception_class,
|
|
3056
|
+
)
|
|
3057
|
+
# Post-process to fill in default kwargs for function calls (recursively)
|
|
3058
|
+
if result is not None:
|
|
3059
|
+
self._fill_default_kwargs_recursive(result)
|
|
3060
|
+
return result
|
|
3061
|
+
|
|
3062
|
+
def _fill_default_kwargs_recursive(self, expr: ir.Expr) -> None:
|
|
3063
|
+
"""Recursively fill in default kwargs for all function calls in an expression."""
|
|
3064
|
+
if expr.HasField("function_call"):
|
|
3065
|
+
self._fill_default_kwargs_for_expr(expr.function_call)
|
|
3066
|
+
# Recurse into function call args and kwargs
|
|
3067
|
+
for arg in expr.function_call.args:
|
|
3068
|
+
self._fill_default_kwargs_recursive(arg)
|
|
3069
|
+
for kwarg in expr.function_call.kwargs:
|
|
3070
|
+
if kwarg.value:
|
|
3071
|
+
self._fill_default_kwargs_recursive(kwarg.value)
|
|
3072
|
+
elif expr.HasField("binary_op"):
|
|
3073
|
+
if expr.binary_op.left:
|
|
3074
|
+
self._fill_default_kwargs_recursive(expr.binary_op.left)
|
|
3075
|
+
if expr.binary_op.right:
|
|
3076
|
+
self._fill_default_kwargs_recursive(expr.binary_op.right)
|
|
3077
|
+
elif expr.HasField("unary_op"):
|
|
3078
|
+
if expr.unary_op.operand:
|
|
3079
|
+
self._fill_default_kwargs_recursive(expr.unary_op.operand)
|
|
3080
|
+
elif expr.HasField("list"):
|
|
3081
|
+
for elem in expr.list.elements:
|
|
3082
|
+
self._fill_default_kwargs_recursive(elem)
|
|
3083
|
+
elif expr.HasField("dict"):
|
|
3084
|
+
for entry in expr.dict.entries:
|
|
3085
|
+
if entry.key:
|
|
3086
|
+
self._fill_default_kwargs_recursive(entry.key)
|
|
3087
|
+
if entry.value:
|
|
3088
|
+
self._fill_default_kwargs_recursive(entry.value)
|
|
3089
|
+
elif expr.HasField("index"):
|
|
3090
|
+
if expr.index.object:
|
|
3091
|
+
self._fill_default_kwargs_recursive(expr.index.object)
|
|
3092
|
+
if expr.index.index:
|
|
3093
|
+
self._fill_default_kwargs_recursive(expr.index.index)
|
|
3094
|
+
elif expr.HasField("dot"):
|
|
3095
|
+
if expr.dot.object:
|
|
3096
|
+
self._fill_default_kwargs_recursive(expr.dot.object)
|
|
2775
3097
|
|
|
2776
|
-
|
|
2777
|
-
|
|
3098
|
+
def _fill_default_kwargs_for_expr(self, fn_call: ir.FunctionCall) -> None:
|
|
3099
|
+
"""Fill in missing kwargs with default values from the function signature."""
|
|
3100
|
+
sig_info = self._ctx.function_signatures.get(fn_call.name)
|
|
3101
|
+
if sig_info is None:
|
|
3102
|
+
return
|
|
3103
|
+
|
|
3104
|
+
# Track which parameters are already provided
|
|
3105
|
+
provided_by_position: Set[str] = set()
|
|
3106
|
+
provided_by_kwarg: Set[str] = set()
|
|
3107
|
+
|
|
3108
|
+
# Positional args map to parameters in order
|
|
3109
|
+
for idx, _arg in enumerate(fn_call.args):
|
|
3110
|
+
if idx < len(sig_info.parameters):
|
|
3111
|
+
provided_by_position.add(sig_info.parameters[idx].name)
|
|
3112
|
+
|
|
3113
|
+
# Kwargs are named
|
|
3114
|
+
for kwarg in fn_call.kwargs:
|
|
3115
|
+
provided_by_kwarg.add(kwarg.name)
|
|
3116
|
+
|
|
3117
|
+
# Add defaults for missing parameters
|
|
3118
|
+
for param in sig_info.parameters:
|
|
3119
|
+
if param.name in provided_by_position or param.name in provided_by_kwarg:
|
|
3120
|
+
continue
|
|
3121
|
+
if not param.has_default:
|
|
3122
|
+
continue
|
|
3123
|
+
|
|
3124
|
+
# Convert the default value to an IR expression
|
|
3125
|
+
literal = _constant_to_literal(param.default_value)
|
|
3126
|
+
if literal is not None:
|
|
3127
|
+
expr = ir.Expr()
|
|
3128
|
+
expr.literal.CopyFrom(literal)
|
|
3129
|
+
fn_call.kwargs.append(ir.Kwarg(name=param.name, value=expr))
|
|
2778
3130
|
|
|
2779
3131
|
def _extract_action_call_from_awaitable(self, node: ast.expr) -> Optional[ir.ActionCall]:
|
|
2780
3132
|
"""Extract action call from an awaitable expression."""
|
|
@@ -2869,10 +3221,200 @@ def _make_span(node: ast.AST) -> ir.Span:
|
|
|
2869
3221
|
)
|
|
2870
3222
|
|
|
2871
3223
|
|
|
2872
|
-
def
|
|
2873
|
-
|
|
3224
|
+
def _attribute_chain(node: ast.Attribute) -> Optional[List[str]]:
|
|
3225
|
+
parts: List[str] = []
|
|
3226
|
+
current: ast.AST = node
|
|
3227
|
+
while isinstance(current, ast.Attribute):
|
|
3228
|
+
parts.append(current.attr)
|
|
3229
|
+
current = current.value
|
|
3230
|
+
if isinstance(current, ast.Name):
|
|
3231
|
+
parts.append(current.id)
|
|
3232
|
+
return list(reversed(parts))
|
|
3233
|
+
return None
|
|
3234
|
+
|
|
3235
|
+
|
|
3236
|
+
def _resolve_enum_attribute_value(
|
|
3237
|
+
node: ast.Attribute,
|
|
3238
|
+
module_globals: Mapping[str, Any],
|
|
3239
|
+
) -> Optional[Any]:
|
|
3240
|
+
chain = _attribute_chain(node)
|
|
3241
|
+
if not chain or len(chain) < 2:
|
|
3242
|
+
return None
|
|
3243
|
+
|
|
3244
|
+
current = module_globals.get(chain[0])
|
|
3245
|
+
if current is None:
|
|
3246
|
+
return None
|
|
3247
|
+
|
|
3248
|
+
for part in chain[1:-1]:
|
|
3249
|
+
try:
|
|
3250
|
+
current_dict = current.__dict__
|
|
3251
|
+
except AttributeError:
|
|
3252
|
+
return None
|
|
3253
|
+
current = current_dict.get(part)
|
|
3254
|
+
if current is None:
|
|
3255
|
+
return None
|
|
3256
|
+
|
|
3257
|
+
member_name = chain[-1]
|
|
3258
|
+
if isinstance(current, EnumMeta):
|
|
3259
|
+
member = current.__members__.get(member_name)
|
|
3260
|
+
if member is None:
|
|
3261
|
+
return None
|
|
3262
|
+
return member.value
|
|
3263
|
+
|
|
3264
|
+
return None
|
|
3265
|
+
|
|
3266
|
+
|
|
3267
|
+
def _try_convert_isinstance_to_isexception(
|
|
3268
|
+
expr: ast.Call,
|
|
3269
|
+
exception_class_resolver: Callable[[str], bool],
|
|
3270
|
+
model_converter: Optional[Callable[[ast.Call], Optional[ir.Expr]]] = None,
|
|
3271
|
+
enum_resolver: Optional[Callable[[ast.Attribute], Optional[ir.Expr]]] = None,
|
|
3272
|
+
) -> Optional[ir.Expr]:
|
|
3273
|
+
"""Try to convert isinstance(x, ExceptionClass) to isexception(x, "ExceptionClass").
|
|
3274
|
+
|
|
3275
|
+
Returns None if this is not an isinstance call or if the class is not an exception.
|
|
3276
|
+
Raises UnsupportedPatternError if isinstance is used with a non-exception class.
|
|
3277
|
+
"""
|
|
3278
|
+
# Check if this is an isinstance call
|
|
3279
|
+
func_name = _get_func_name(expr.func)
|
|
3280
|
+
if func_name != "isinstance":
|
|
3281
|
+
return None
|
|
3282
|
+
|
|
3283
|
+
# isinstance requires exactly 2 positional arguments
|
|
3284
|
+
if len(expr.args) != 2 or expr.keywords:
|
|
3285
|
+
line = expr.lineno if hasattr(expr, "lineno") else None
|
|
3286
|
+
col = expr.col_offset if hasattr(expr, "col_offset") else None
|
|
3287
|
+
raise UnsupportedPatternError(
|
|
3288
|
+
"isinstance() requires exactly 2 positional arguments",
|
|
3289
|
+
RECOMMENDATIONS["sync_function_call"],
|
|
3290
|
+
line=line,
|
|
3291
|
+
col=col,
|
|
3292
|
+
)
|
|
3293
|
+
|
|
3294
|
+
# Extract exception class names from the second argument
|
|
3295
|
+
second_arg = expr.args[1]
|
|
3296
|
+
exception_names: List[str] = []
|
|
3297
|
+
|
|
3298
|
+
if isinstance(second_arg, ast.Name):
|
|
3299
|
+
# Single class: isinstance(x, ValueError)
|
|
3300
|
+
class_name = second_arg.id
|
|
3301
|
+
if not exception_class_resolver(class_name):
|
|
3302
|
+
line = expr.lineno if hasattr(expr, "lineno") else None
|
|
3303
|
+
col = expr.col_offset if hasattr(expr, "col_offset") else None
|
|
3304
|
+
raise UnsupportedPatternError(
|
|
3305
|
+
f"isinstance() with non-exception class '{class_name}' is not supported in workflows",
|
|
3306
|
+
"isinstance() can only be used to check exception types in workflow code. "
|
|
3307
|
+
"Move other type checks to an @action.",
|
|
3308
|
+
line=line,
|
|
3309
|
+
col=col,
|
|
3310
|
+
)
|
|
3311
|
+
exception_names.append(class_name)
|
|
3312
|
+
elif isinstance(second_arg, ast.Tuple):
|
|
3313
|
+
# Tuple of classes: isinstance(x, (ValueError, TypeError))
|
|
3314
|
+
for elt in second_arg.elts:
|
|
3315
|
+
if not isinstance(elt, ast.Name):
|
|
3316
|
+
line = expr.lineno if hasattr(expr, "lineno") else None
|
|
3317
|
+
col = expr.col_offset if hasattr(expr, "col_offset") else None
|
|
3318
|
+
raise UnsupportedPatternError(
|
|
3319
|
+
"isinstance() class argument must be a simple name or tuple of names",
|
|
3320
|
+
"Use simple class names like ValueError or (ValueError, TypeError).",
|
|
3321
|
+
line=line,
|
|
3322
|
+
col=col,
|
|
3323
|
+
)
|
|
3324
|
+
class_name = elt.id
|
|
3325
|
+
if not exception_class_resolver(class_name):
|
|
3326
|
+
line = expr.lineno if hasattr(expr, "lineno") else None
|
|
3327
|
+
col = expr.col_offset if hasattr(expr, "col_offset") else None
|
|
3328
|
+
raise UnsupportedPatternError(
|
|
3329
|
+
f"isinstance() with non-exception class '{class_name}' is not supported in workflows",
|
|
3330
|
+
"isinstance() can only be used to check exception types in workflow code. "
|
|
3331
|
+
"Move other type checks to an @action.",
|
|
3332
|
+
line=line,
|
|
3333
|
+
col=col,
|
|
3334
|
+
)
|
|
3335
|
+
exception_names.append(class_name)
|
|
3336
|
+
else:
|
|
3337
|
+
line = expr.lineno if hasattr(expr, "lineno") else None
|
|
3338
|
+
col = expr.col_offset if hasattr(expr, "col_offset") else None
|
|
3339
|
+
raise UnsupportedPatternError(
|
|
3340
|
+
"isinstance() class argument must be a simple name or tuple of names",
|
|
3341
|
+
"Use simple class names like ValueError or (ValueError, TypeError).",
|
|
3342
|
+
line=line,
|
|
3343
|
+
col=col,
|
|
3344
|
+
)
|
|
3345
|
+
|
|
3346
|
+
# Convert the first argument (the value being checked) to IR
|
|
3347
|
+
value_expr = _expr_to_ir(
|
|
3348
|
+
expr.args[0],
|
|
3349
|
+
model_converter=model_converter,
|
|
3350
|
+
enum_resolver=enum_resolver,
|
|
3351
|
+
exception_class_resolver=exception_class_resolver,
|
|
3352
|
+
)
|
|
3353
|
+
if not value_expr:
|
|
3354
|
+
return None
|
|
3355
|
+
|
|
3356
|
+
# Build the isexception call
|
|
3357
|
+
# If single exception: isexception(x, "ValueError")
|
|
3358
|
+
# If multiple: isexception(x, ["ValueError", "TypeError"])
|
|
2874
3359
|
result = ir.Expr(span=_make_span(expr))
|
|
2875
3360
|
|
|
3361
|
+
if len(exception_names) == 1:
|
|
3362
|
+
# Single exception name as string
|
|
3363
|
+
type_arg = ir.Expr(span=_make_span(second_arg))
|
|
3364
|
+
type_arg.literal.CopyFrom(ir.Literal(string_value=exception_names[0]))
|
|
3365
|
+
args = [value_expr, type_arg]
|
|
3366
|
+
else:
|
|
3367
|
+
# Multiple exception names as list of strings
|
|
3368
|
+
type_elements = []
|
|
3369
|
+
for name in exception_names:
|
|
3370
|
+
elem = ir.Expr()
|
|
3371
|
+
elem.literal.CopyFrom(ir.Literal(string_value=name))
|
|
3372
|
+
type_elements.append(elem)
|
|
3373
|
+
type_arg = ir.Expr(span=_make_span(second_arg))
|
|
3374
|
+
type_arg.list.CopyFrom(ir.ListExpr(elements=type_elements))
|
|
3375
|
+
args = [value_expr, type_arg]
|
|
3376
|
+
|
|
3377
|
+
func_call = ir.FunctionCall(
|
|
3378
|
+
name="isexception",
|
|
3379
|
+
args=args,
|
|
3380
|
+
kwargs=[],
|
|
3381
|
+
)
|
|
3382
|
+
func_call.global_function = ir.GlobalFunction.GLOBAL_FUNCTION_ISEXCEPTION
|
|
3383
|
+
result.function_call.CopyFrom(func_call)
|
|
3384
|
+
return result
|
|
3385
|
+
|
|
3386
|
+
|
|
3387
|
+
def _expr_to_ir(
|
|
3388
|
+
expr: ast.AST,
|
|
3389
|
+
model_converter: Optional[Callable[[ast.Call], Optional[ir.Expr]]] = None,
|
|
3390
|
+
enum_resolver: Optional[Callable[[ast.Attribute], Optional[ir.Expr]]] = None,
|
|
3391
|
+
exception_class_resolver: Optional[Callable[[str], bool]] = None,
|
|
3392
|
+
) -> Optional[ir.Expr]:
|
|
3393
|
+
"""Convert Python AST expression to IR Expr.
|
|
3394
|
+
|
|
3395
|
+
Args:
|
|
3396
|
+
expr: The AST expression to convert.
|
|
3397
|
+
model_converter: Optional callback to convert model constructors.
|
|
3398
|
+
enum_resolver: Optional callback to resolve enum attributes.
|
|
3399
|
+
exception_class_resolver: Optional callback that takes a class name and
|
|
3400
|
+
returns True if it's an exception class. Used to transform
|
|
3401
|
+
isinstance(x, ExceptionClass) to isexception(x, "ExceptionClass").
|
|
3402
|
+
"""
|
|
3403
|
+
result = ir.Expr(span=_make_span(expr))
|
|
3404
|
+
|
|
3405
|
+
if isinstance(expr, ast.Call) and model_converter:
|
|
3406
|
+
converted = model_converter(expr)
|
|
3407
|
+
if converted:
|
|
3408
|
+
return converted
|
|
3409
|
+
|
|
3410
|
+
# Handle isinstance(x, ExceptionClass) -> isexception(x, "ExceptionClass")
|
|
3411
|
+
if isinstance(expr, ast.Call) and exception_class_resolver:
|
|
3412
|
+
isinstance_result = _try_convert_isinstance_to_isexception(
|
|
3413
|
+
expr, exception_class_resolver, model_converter, enum_resolver
|
|
3414
|
+
)
|
|
3415
|
+
if isinstance_result is not None:
|
|
3416
|
+
return isinstance_result
|
|
3417
|
+
|
|
2876
3418
|
if isinstance(expr, ast.Name):
|
|
2877
3419
|
result.variable.CopyFrom(ir.Variable(name=expr.id))
|
|
2878
3420
|
return result
|
|
@@ -2884,34 +3426,66 @@ def _expr_to_ir(expr: ast.AST) -> Optional[ir.Expr]:
|
|
|
2884
3426
|
return result
|
|
2885
3427
|
|
|
2886
3428
|
if isinstance(expr, ast.BinOp):
|
|
2887
|
-
left = _expr_to_ir(
|
|
2888
|
-
|
|
3429
|
+
left = _expr_to_ir(
|
|
3430
|
+
expr.left,
|
|
3431
|
+
model_converter=model_converter,
|
|
3432
|
+
enum_resolver=enum_resolver,
|
|
3433
|
+
exception_class_resolver=exception_class_resolver,
|
|
3434
|
+
)
|
|
3435
|
+
right = _expr_to_ir(
|
|
3436
|
+
expr.right,
|
|
3437
|
+
model_converter=model_converter,
|
|
3438
|
+
enum_resolver=enum_resolver,
|
|
3439
|
+
exception_class_resolver=exception_class_resolver,
|
|
3440
|
+
)
|
|
2889
3441
|
op = _bin_op_to_ir(expr.op)
|
|
2890
3442
|
if left and right and op:
|
|
2891
3443
|
result.binary_op.CopyFrom(ir.BinaryOp(left=left, op=op, right=right))
|
|
2892
3444
|
return result
|
|
2893
3445
|
|
|
2894
3446
|
if isinstance(expr, ast.UnaryOp):
|
|
2895
|
-
operand = _expr_to_ir(
|
|
3447
|
+
operand = _expr_to_ir(
|
|
3448
|
+
expr.operand,
|
|
3449
|
+
model_converter=model_converter,
|
|
3450
|
+
enum_resolver=enum_resolver,
|
|
3451
|
+
exception_class_resolver=exception_class_resolver,
|
|
3452
|
+
)
|
|
2896
3453
|
op = _unary_op_to_ir(expr.op)
|
|
2897
3454
|
if operand and op:
|
|
2898
3455
|
result.unary_op.CopyFrom(ir.UnaryOp(op=op, operand=operand))
|
|
2899
3456
|
return result
|
|
2900
3457
|
|
|
2901
3458
|
if isinstance(expr, ast.Compare):
|
|
2902
|
-
left = _expr_to_ir(
|
|
3459
|
+
left = _expr_to_ir(
|
|
3460
|
+
expr.left,
|
|
3461
|
+
model_converter=model_converter,
|
|
3462
|
+
enum_resolver=enum_resolver,
|
|
3463
|
+
exception_class_resolver=exception_class_resolver,
|
|
3464
|
+
)
|
|
2903
3465
|
if not left:
|
|
2904
3466
|
return None
|
|
2905
3467
|
# For simplicity, handle single comparison
|
|
2906
3468
|
if expr.ops and expr.comparators:
|
|
2907
3469
|
op = _cmp_op_to_ir(expr.ops[0])
|
|
2908
|
-
right = _expr_to_ir(
|
|
3470
|
+
right = _expr_to_ir(
|
|
3471
|
+
expr.comparators[0],
|
|
3472
|
+
model_converter=model_converter,
|
|
3473
|
+
enum_resolver=enum_resolver,
|
|
3474
|
+
)
|
|
2909
3475
|
if op and right:
|
|
2910
3476
|
result.binary_op.CopyFrom(ir.BinaryOp(left=left, op=op, right=right))
|
|
2911
3477
|
return result
|
|
2912
3478
|
|
|
2913
3479
|
if isinstance(expr, ast.BoolOp):
|
|
2914
|
-
values = [
|
|
3480
|
+
values = [
|
|
3481
|
+
_expr_to_ir(
|
|
3482
|
+
v,
|
|
3483
|
+
model_converter=model_converter,
|
|
3484
|
+
enum_resolver=enum_resolver,
|
|
3485
|
+
exception_class_resolver=exception_class_resolver,
|
|
3486
|
+
)
|
|
3487
|
+
for v in expr.values
|
|
3488
|
+
]
|
|
2915
3489
|
if all(v for v in values):
|
|
2916
3490
|
op = _bool_op_to_ir(expr.op)
|
|
2917
3491
|
if op and len(values) >= 2:
|
|
@@ -2925,7 +3499,15 @@ def _expr_to_ir(expr: ast.AST) -> Optional[ir.Expr]:
|
|
|
2925
3499
|
return result_expr
|
|
2926
3500
|
|
|
2927
3501
|
if isinstance(expr, ast.List):
|
|
2928
|
-
elements = [
|
|
3502
|
+
elements = [
|
|
3503
|
+
_expr_to_ir(
|
|
3504
|
+
e,
|
|
3505
|
+
model_converter=model_converter,
|
|
3506
|
+
enum_resolver=enum_resolver,
|
|
3507
|
+
exception_class_resolver=exception_class_resolver,
|
|
3508
|
+
)
|
|
3509
|
+
for e in expr.elts
|
|
3510
|
+
]
|
|
2929
3511
|
if all(e for e in elements):
|
|
2930
3512
|
list_expr = ir.ListExpr(elements=[e for e in elements if e])
|
|
2931
3513
|
result.list.CopyFrom(list_expr)
|
|
@@ -2935,35 +3517,128 @@ def _expr_to_ir(expr: ast.AST) -> Optional[ir.Expr]:
|
|
|
2935
3517
|
entries: List[ir.DictEntry] = []
|
|
2936
3518
|
for k, v in zip(expr.keys, expr.values, strict=False):
|
|
2937
3519
|
if k:
|
|
2938
|
-
key_expr = _expr_to_ir(
|
|
2939
|
-
|
|
3520
|
+
key_expr = _expr_to_ir(
|
|
3521
|
+
k,
|
|
3522
|
+
model_converter=model_converter,
|
|
3523
|
+
enum_resolver=enum_resolver,
|
|
3524
|
+
)
|
|
3525
|
+
value_expr = _expr_to_ir(
|
|
3526
|
+
v,
|
|
3527
|
+
model_converter=model_converter,
|
|
3528
|
+
enum_resolver=enum_resolver,
|
|
3529
|
+
)
|
|
2940
3530
|
if key_expr and value_expr:
|
|
2941
3531
|
entries.append(ir.DictEntry(key=key_expr, value=value_expr))
|
|
2942
3532
|
result.dict.CopyFrom(ir.DictExpr(entries=entries))
|
|
2943
3533
|
return result
|
|
2944
3534
|
|
|
2945
3535
|
if isinstance(expr, ast.Subscript):
|
|
2946
|
-
obj = _expr_to_ir(
|
|
2947
|
-
|
|
3536
|
+
obj = _expr_to_ir(
|
|
3537
|
+
expr.value,
|
|
3538
|
+
model_converter=model_converter,
|
|
3539
|
+
enum_resolver=enum_resolver,
|
|
3540
|
+
exception_class_resolver=exception_class_resolver,
|
|
3541
|
+
)
|
|
3542
|
+
index = (
|
|
3543
|
+
_expr_to_ir(
|
|
3544
|
+
expr.slice,
|
|
3545
|
+
model_converter=model_converter,
|
|
3546
|
+
enum_resolver=enum_resolver,
|
|
3547
|
+
)
|
|
3548
|
+
if isinstance(expr.slice, ast.AST)
|
|
3549
|
+
else None
|
|
3550
|
+
)
|
|
2948
3551
|
if obj and index:
|
|
2949
3552
|
result.index.CopyFrom(ir.IndexAccess(object=obj, index=index))
|
|
2950
3553
|
return result
|
|
2951
3554
|
|
|
2952
3555
|
if isinstance(expr, ast.Attribute):
|
|
2953
|
-
|
|
3556
|
+
if enum_resolver:
|
|
3557
|
+
resolved = enum_resolver(expr)
|
|
3558
|
+
if resolved:
|
|
3559
|
+
return resolved
|
|
3560
|
+
obj = _expr_to_ir(
|
|
3561
|
+
expr.value,
|
|
3562
|
+
model_converter=model_converter,
|
|
3563
|
+
enum_resolver=enum_resolver,
|
|
3564
|
+
exception_class_resolver=exception_class_resolver,
|
|
3565
|
+
)
|
|
2954
3566
|
if obj:
|
|
2955
3567
|
result.dot.CopyFrom(ir.DotAccess(object=obj, attribute=expr.attr))
|
|
2956
3568
|
return result
|
|
2957
3569
|
|
|
3570
|
+
if isinstance(expr, ast.Await) and isinstance(expr.value, ast.Call):
|
|
3571
|
+
func_name = _get_func_name(expr.value.func)
|
|
3572
|
+
if func_name:
|
|
3573
|
+
args = [
|
|
3574
|
+
_expr_to_ir(
|
|
3575
|
+
a,
|
|
3576
|
+
model_converter=model_converter,
|
|
3577
|
+
enum_resolver=enum_resolver,
|
|
3578
|
+
)
|
|
3579
|
+
for a in expr.value.args
|
|
3580
|
+
]
|
|
3581
|
+
kwargs: List[ir.Kwarg] = []
|
|
3582
|
+
for kw in expr.value.keywords:
|
|
3583
|
+
if kw.arg:
|
|
3584
|
+
kw_expr = _expr_to_ir(
|
|
3585
|
+
kw.value,
|
|
3586
|
+
model_converter=model_converter,
|
|
3587
|
+
enum_resolver=enum_resolver,
|
|
3588
|
+
)
|
|
3589
|
+
if kw_expr:
|
|
3590
|
+
kwargs.append(ir.Kwarg(name=kw.arg, value=kw_expr))
|
|
3591
|
+
func_call = ir.FunctionCall(
|
|
3592
|
+
name=func_name,
|
|
3593
|
+
args=[a for a in args if a],
|
|
3594
|
+
kwargs=kwargs,
|
|
3595
|
+
)
|
|
3596
|
+
global_function = _global_function_for_call(func_name, expr.value)
|
|
3597
|
+
if global_function is not None:
|
|
3598
|
+
func_call.global_function = global_function
|
|
3599
|
+
result.function_call.CopyFrom(func_call)
|
|
3600
|
+
return result
|
|
3601
|
+
|
|
2958
3602
|
if isinstance(expr, ast.Call):
|
|
2959
3603
|
# Function call
|
|
3604
|
+
if not _is_self_method_call(expr):
|
|
3605
|
+
func_name = _get_func_name(expr.func) or "unknown"
|
|
3606
|
+
if isinstance(expr.func, ast.Attribute):
|
|
3607
|
+
line = expr.lineno if hasattr(expr, "lineno") else None
|
|
3608
|
+
col = expr.col_offset if hasattr(expr, "col_offset") else None
|
|
3609
|
+
raise UnsupportedPatternError(
|
|
3610
|
+
f"Calling synchronous function '{func_name}()' directly is not supported",
|
|
3611
|
+
RECOMMENDATIONS["sync_function_call"],
|
|
3612
|
+
line=line,
|
|
3613
|
+
col=col,
|
|
3614
|
+
)
|
|
3615
|
+
if func_name not in ALLOWED_SYNC_FUNCTIONS:
|
|
3616
|
+
line = expr.lineno if hasattr(expr, "lineno") else None
|
|
3617
|
+
col = expr.col_offset if hasattr(expr, "col_offset") else None
|
|
3618
|
+
raise UnsupportedPatternError(
|
|
3619
|
+
f"Calling synchronous function '{func_name}()' directly is not supported",
|
|
3620
|
+
RECOMMENDATIONS["sync_function_call"],
|
|
3621
|
+
line=line,
|
|
3622
|
+
col=col,
|
|
3623
|
+
)
|
|
2960
3624
|
func_name = _get_func_name(expr.func)
|
|
2961
3625
|
if func_name:
|
|
2962
|
-
args = [
|
|
3626
|
+
args = [
|
|
3627
|
+
_expr_to_ir(
|
|
3628
|
+
a,
|
|
3629
|
+
model_converter=model_converter,
|
|
3630
|
+
enum_resolver=enum_resolver,
|
|
3631
|
+
)
|
|
3632
|
+
for a in expr.args
|
|
3633
|
+
]
|
|
2963
3634
|
kwargs: List[ir.Kwarg] = []
|
|
2964
3635
|
for kw in expr.keywords:
|
|
2965
3636
|
if kw.arg:
|
|
2966
|
-
kw_expr = _expr_to_ir(
|
|
3637
|
+
kw_expr = _expr_to_ir(
|
|
3638
|
+
kw.value,
|
|
3639
|
+
model_converter=model_converter,
|
|
3640
|
+
enum_resolver=enum_resolver,
|
|
3641
|
+
)
|
|
2967
3642
|
if kw_expr:
|
|
2968
3643
|
kwargs.append(ir.Kwarg(name=kw.arg, value=kw_expr))
|
|
2969
3644
|
func_call = ir.FunctionCall(
|
|
@@ -2971,12 +3646,23 @@ def _expr_to_ir(expr: ast.AST) -> Optional[ir.Expr]:
|
|
|
2971
3646
|
args=[a for a in args if a],
|
|
2972
3647
|
kwargs=kwargs,
|
|
2973
3648
|
)
|
|
3649
|
+
global_function = _global_function_for_call(func_name, expr)
|
|
3650
|
+
if global_function is not None:
|
|
3651
|
+
func_call.global_function = global_function
|
|
2974
3652
|
result.function_call.CopyFrom(func_call)
|
|
2975
3653
|
return result
|
|
2976
3654
|
|
|
2977
3655
|
if isinstance(expr, ast.Tuple):
|
|
2978
3656
|
# Handle tuple as list for now
|
|
2979
|
-
elements = [
|
|
3657
|
+
elements = [
|
|
3658
|
+
_expr_to_ir(
|
|
3659
|
+
e,
|
|
3660
|
+
model_converter=model_converter,
|
|
3661
|
+
enum_resolver=enum_resolver,
|
|
3662
|
+
exception_class_resolver=exception_class_resolver,
|
|
3663
|
+
)
|
|
3664
|
+
for e in expr.elts
|
|
3665
|
+
]
|
|
2980
3666
|
if all(e for e in elements):
|
|
2981
3667
|
list_expr = ir.ListExpr(elements=[e for e in elements if e])
|
|
2982
3668
|
result.list.CopyFrom(list_expr)
|
|
@@ -2993,6 +3679,15 @@ def _check_unsupported_expression(expr: ast.AST) -> None:
|
|
|
2993
3679
|
line = getattr(expr, "lineno", None)
|
|
2994
3680
|
col = getattr(expr, "col_offset", None)
|
|
2995
3681
|
|
|
3682
|
+
if isinstance(expr, ast.Constant):
|
|
3683
|
+
if _constant_to_literal(expr.value) is None:
|
|
3684
|
+
raise UnsupportedPatternError(
|
|
3685
|
+
f"Unsupported literal type '{type(expr.value).__name__}'",
|
|
3686
|
+
RECOMMENDATIONS["unsupported_literal"],
|
|
3687
|
+
line=line,
|
|
3688
|
+
col=col,
|
|
3689
|
+
)
|
|
3690
|
+
|
|
2996
3691
|
if isinstance(expr, ast.JoinedStr):
|
|
2997
3692
|
raise UnsupportedPatternError(
|
|
2998
3693
|
"F-strings are not supported",
|
|
@@ -3049,6 +3744,13 @@ def _check_unsupported_expression(expr: ast.AST) -> None:
|
|
|
3049
3744
|
line=line,
|
|
3050
3745
|
col=col,
|
|
3051
3746
|
)
|
|
3747
|
+
elif isinstance(expr, ast.expr):
|
|
3748
|
+
raise UnsupportedPatternError(
|
|
3749
|
+
f"Unsupported expression type '{type(expr).__name__}'",
|
|
3750
|
+
RECOMMENDATIONS["unsupported_expression"],
|
|
3751
|
+
line=line,
|
|
3752
|
+
col=col,
|
|
3753
|
+
)
|
|
3052
3754
|
|
|
3053
3755
|
|
|
3054
3756
|
def _format_subscript_target(target: ast.Subscript) -> Optional[str]:
|
|
@@ -3142,5 +3844,27 @@ def _get_func_name(func: ast.expr) -> Optional[str]:
|
|
|
3142
3844
|
current = current.value
|
|
3143
3845
|
if isinstance(current, ast.Name):
|
|
3144
3846
|
parts.append(current.id)
|
|
3145
|
-
|
|
3847
|
+
name = ".".join(reversed(parts))
|
|
3848
|
+
if name.startswith("self."):
|
|
3849
|
+
return name[5:]
|
|
3850
|
+
return name
|
|
3146
3851
|
return None
|
|
3852
|
+
|
|
3853
|
+
|
|
3854
|
+
def _is_self_method_call(node: ast.Call) -> bool:
|
|
3855
|
+
"""Return True if the call is a direct self.method(...) invocation."""
|
|
3856
|
+
func = node.func
|
|
3857
|
+
return (
|
|
3858
|
+
isinstance(func, ast.Attribute)
|
|
3859
|
+
and isinstance(func.value, ast.Name)
|
|
3860
|
+
and func.value.id == "self"
|
|
3861
|
+
)
|
|
3862
|
+
|
|
3863
|
+
|
|
3864
|
+
def _global_function_for_call(
|
|
3865
|
+
func_name: str, node: ast.Call
|
|
3866
|
+
) -> Optional[ir.GlobalFunction.ValueType]:
|
|
3867
|
+
"""Return the GlobalFunction enum value for supported globals."""
|
|
3868
|
+
if _is_self_method_call(node):
|
|
3869
|
+
return None
|
|
3870
|
+
return GLOBAL_FUNCTIONS.get(func_name)
|