rappel 0.5.1__py3-none-manylinux_2_39_aarch64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
rappel/ir_builder.py ADDED
@@ -0,0 +1,2966 @@
1
+ """
2
+ IR Builder - Converts Python workflow AST to Rappel IR (ast.proto).
3
+
4
+ This module parses Python workflow classes and produces the IR representation
5
+ that can be sent to the Rust runtime for execution.
6
+
7
+ The IR builder performs deep transformations to convert Python patterns into
8
+ valid Rappel IR structures. Control flow bodies (try, for, if branches) are
9
+ represented as full blocks, so they can contain arbitrary statements and
10
+ `return` behaves consistently with Python (returns from `run()` end the
11
+ workflow).
12
+
13
+ Validation:
14
+ The IR builder proactively detects unsupported Python patterns and raises
15
+ UnsupportedPatternError with clear recommendations for how to rewrite the code.
16
+ """
17
+
18
+ from __future__ import annotations
19
+
20
+ import ast
21
+ import copy
22
+ import inspect
23
+ import textwrap
24
+ from dataclasses import dataclass
25
+ from typing import TYPE_CHECKING, Any, Dict, List, Optional, Set, Union
26
+
27
+ from proto import ast_pb2 as ir
28
+
29
+
30
class UnsupportedPatternError(Exception):
    """Signals workflow code that the IR builder cannot translate.

    Besides the human-readable ``message``, the error carries a
    ``recommendation`` describing a supported rewrite, plus the optional
    source position of the offending node. ``str(error)`` renders as::

        <message>[ (line N)]

        Recommendation: <recommendation>

    The column is kept on the instance but is not part of the text.
    """

    def __init__(
        self,
        message: str,
        recommendation: str,
        line: Optional[int] = None,
        col: Optional[int] = None,
    ):
        self.message, self.recommendation = message, recommendation
        self.line, self.col = line, col

        # A missing (or zero) line number omits the location suffix.
        location = "" if not line else f" (line {line})"
        super().__init__(f"{message}{location}\n\nRecommendation: {recommendation}")
52
+
53
+
54
# Rewrite advice attached to UnsupportedPatternError, keyed by pattern name.
#
# Two pairs of entries share identical text, so that text is defined once
# here and spliced into both messages of each pair.
_OBJECT_CREATION_LEAD_IN = (
    "The workflow IR cannot serialize arbitrary object instantiation.\n\n"
    "Use an @action to create the object:\n\n"
)

_GATHER_COMPREHENSION_EXAMPLE = (
    " # Instead of:\n"
    " tasks = []\n"
    " for i in range(count):\n"
    " tasks.append(process(value=i))\n"
    " results = await asyncio.gather(*tasks)\n\n"
    " # Use:\n"
    " results = await asyncio.gather(*[process(value=i) for i in range(count)])"
)

RECOMMENDATIONS = {
    "constructor_return": (
        "Returning a class constructor (like MyModel(...)) directly is not supported.\n"
        + _OBJECT_CREATION_LEAD_IN
        + " @action\n"
        " async def build_result(items: list, count: int) -> MyResult:\n"
        " return MyResult(items=items, count=count)\n\n"
        " # In workflow:\n"
        " return await build_result(items, count)"
    ),
    "constructor_assignment": (
        "Assigning a class constructor result (like x = MyClass(...)) is not supported.\n"
        + _OBJECT_CREATION_LEAD_IN
        + " @action\n"
        " async def create_config(value: int) -> Config:\n"
        " return Config(value=value)\n\n"
        " # In workflow:\n"
        " config = await create_config(value)"
    ),
    "non_action_call": (
        "Calling a function that is not decorated with @action is not supported.\n"
        "Only @action decorated functions can be awaited in workflow code.\n\n"
        "Add the @action decorator to your function:\n\n"
        " @action\n"
        " async def my_function(x: int) -> int:\n"
        " return x * 2"
    ),
    "sync_function_call": (
        "Calling a synchronous function directly in workflow code is not supported.\n"
        "All computation must happen inside @action decorated async functions.\n\n"
        "Wrap your logic in an @action:\n\n"
        " @action\n"
        " async def compute(x: int) -> int:\n"
        " return some_sync_function(x)"
    ),
    "method_call_non_self": (
        "Calling methods on objects other than 'self' is not supported in workflow code.\n"
        "Use an @action to perform method calls:\n\n"
        " @action\n"
        " async def call_method(obj: MyClass) -> Result:\n"
        " return obj.some_method()"
    ),
    "builtin_call": (
        "Calling built-in functions like len(), str(), int() directly is not supported.\n"
        "Use an @action to perform these operations:\n\n"
        " @action\n"
        " async def get_length(items: list) -> int:\n"
        " return len(items)"
    ),
    "fstring": (
        "F-strings are not supported in workflow code because they require "
        "runtime string interpolation.\n"
        "Use an @action to perform string formatting:\n\n"
        " @action\n"
        " async def format_message(value: int) -> str:\n"
        " return f'Result: {value}'"
    ),
    "delete": (
        "The 'del' statement is not supported in workflow code.\n"
        "Use an @action to perform mutations:\n\n"
        " @action\n"
        " async def remove_key(data: dict, key: str) -> dict:\n"
        " del data[key]\n"
        " return data"
    ),
    "while_loop": (
        "While loops are not supported in workflow code because they can run "
        "indefinitely.\n"
        "Use a for loop with a fixed range, or restructure as recursive workflow calls."
    ),
    "with_statement": (
        "Context managers (with statements) are not supported in workflow code.\n"
        "Use an @action to handle resource management:\n\n"
        " @action\n"
        " async def read_file(path: str) -> str:\n"
        " with open(path) as f:\n"
        " return f.read()"
    ),
    "raise_statement": (
        "The 'raise' statement is not supported directly in workflow code.\n"
        "Use an @action that raises exceptions, or return error values."
    ),
    "assert_statement": (
        "Assert statements are not supported in workflow code.\n"
        "Use an @action for validation, or use if statements with explicit error handling."
    ),
    "lambda": (
        "Lambda expressions are not supported in workflow code.\n"
        "Use an @action to define the function logic."
    ),
    "list_comprehension": (
        "List comprehensions are only supported when assigned directly to a variable "
        "or inside asyncio.gather(*[...]).\n"
        "For other cases, use a for loop or an @action."
    ),
    "dict_comprehension": (
        "Dict comprehensions are only supported when assigned directly to a variable.\n"
        "For other cases, use a for loop or an @action:\n\n"
        " @action\n"
        " async def build_dict(items: list) -> dict:\n"
        " return {k: v for k, v in items}"
    ),
    "set_comprehension": (
        "Set comprehensions are not supported in workflow code.\n"
        "Use an @action to build sets."
    ),
    "generator": (
        "Generator expressions are not supported in workflow code.\n"
        "Use a list or an @action instead."
    ),
    "walrus": (
        "The walrus operator (:=) is not supported in workflow code.\n"
        "Use separate assignment statements instead."
    ),
    "match": (
        "Match statements are not supported in workflow code.\n"
        "Use if/elif/else chains instead."
    ),
    "gather_variable_spread": (
        "Spreading a variable in asyncio.gather() is not supported because it requires "
        "data flow analysis to determine the contents.\n"
        "Use a list comprehension directly in gather:\n\n"
        + _GATHER_COMPREHENSION_EXAMPLE
    ),
    "for_loop_append_pattern": (
        "Building a task list in a for loop then spreading in asyncio.gather() is not "
        "supported.\n"
        "Use a list comprehension directly in gather:\n\n"
        + _GATHER_COMPREHENSION_EXAMPLE
    ),
    "global_statement": (
        "Global statements are not supported in workflow code.\n"
        "Use workflow state or pass values explicitly."
    ),
    "nonlocal_statement": (
        "Nonlocal statements are not supported in workflow code.\n"
        "Use explicit parameter passing instead."
    ),
    "import_statement": (
        "Import statements inside workflow run() are not supported.\n"
        "Place imports at the module level."
    ),
    "class_def": (
        "Class definitions inside workflow run() are not supported.\n"
        "Define classes at the module level."
    ),
    "function_def": (
        "Nested function definitions inside workflow run() are not supported.\n"
        "Define functions at the module level or use @action."
    ),
    "yield_statement": (
        "Yield statements are not supported in workflow code.\n"
        "Workflows must return a complete result, not generate values incrementally."
    ),
}
222
+
223
+
224
# Import Workflow for static type checkers only; at runtime TYPE_CHECKING is
# False, so the name is never bound here. With `from __future__ import
# annotations`, annotations such as type["Workflow"] stay unevaluated strings.
if TYPE_CHECKING:
    from .workflow import Workflow
226
+
227
+
228
@dataclass
class ActionDefinition:
    """Metadata for one ``@action`` function found in a workflow's module.

    Instances are produced by ``_discover_action_names`` and keyed there by
    the attribute name the function has in that module.
    """

    # Value of the function's ``__rappel_action_name__`` marker.
    action_name: str
    # ``__rappel_action_module__`` when set, otherwise the scanned module's name.
    module_name: Optional[str]
    # ``inspect.signature`` of the decorated callable.
    signature: inspect.Signature
235
+
236
+
237
@dataclass
class TransformContext:
    """Mutable state shared across one IR build.

    It hands out unique synthetic function names and collects the
    ``ir.FunctionDef`` objects synthesized while lowering ``run()``.
    ``build_workflow_ir`` places them ahead of the main function.
    """

    # Incremented before each generated name, so the first name uses 1.
    implicit_fn_counter: int = 0
    # Synthesized function definitions; a None value is replaced with a
    # fresh list per instance in __post_init__.
    implicit_functions: List[ir.FunctionDef] = None  # type: ignore

    def __post_init__(self) -> None:
        if self.implicit_functions is not None:
            return
        self.implicit_functions = []

    def next_implicit_fn_name(self, prefix: str = "implicit") -> str:
        """Return a fresh name of the form ``__<prefix>_<n>__``."""
        self.implicit_fn_counter += 1
        return "__{}_{}__".format(prefix, self.implicit_fn_counter)
254
+
255
+
256
def build_workflow_ir(workflow_cls: type["Workflow"]) -> ir.Program:
    """Lower a workflow class's run implementation into an IR Program.

    The implementation is taken from ``__workflow_run_impl__`` when present.
    Otherwise it falls back to a ``run`` defined directly on the class.

    Its source is parsed with :mod:`ast` and walked by ``IRBuilder``. The
    builder uses lookup tables discovered from the defining module: actions,
    from-imports, async functions and simple models.

    Args:
        workflow_cls: The workflow class to convert.

    Returns:
        An ``ir.Program`` whose ``functions`` hold any implicit helper
        functions generated during lowering, followed by the main function.

    Raises:
        ValueError: if no run implementation or no defining module is found.
    """
    run_impl = getattr(workflow_cls, "__workflow_run_impl__", None)
    if run_impl is None:
        # Only a `run` in the class's own namespace is considered here.
        run_impl = workflow_cls.__dict__.get("run")
    if run_impl is None:
        raise ValueError(f"workflow {workflow_cls!r} missing run implementation")

    defining_module = inspect.getmodule(run_impl)
    if defining_module is None:
        raise ValueError(f"unable to locate module for workflow {workflow_cls!r}")

    # getsource() keeps the method's indentation; dedent so it parses alone.
    run_tree = ast.parse(textwrap.dedent(inspect.getsource(run_impl)))

    # Module-level knowledge consulted while lowering run().
    action_defs = _discover_action_names(defining_module)
    imported_names = _discover_module_imports(defining_module)
    module_functions = _discover_module_functions(defining_module)
    model_defs = _discover_model_definitions(defining_module)

    ctx = TransformContext()
    builder = IRBuilder(action_defs, ctx, imported_names, module_functions, model_defs)
    builder.visit(run_tree)

    program = ir.Program()
    # Implicit helpers precede the main function, which may call them.
    program.functions.extend(ctx.implicit_functions)
    if builder.function_def:
        program.functions.append(builder.function_def)
    return program
308
+
309
+
310
+ def _discover_action_names(module: Any) -> Dict[str, ActionDefinition]:
311
+ """Discover all @action decorated functions in a module."""
312
+ names: Dict[str, ActionDefinition] = {}
313
+ for attr_name in dir(module):
314
+ attr = getattr(module, attr_name)
315
+ action_name = getattr(attr, "__rappel_action_name__", None)
316
+ action_module = getattr(attr, "__rappel_action_module__", None)
317
+ if callable(attr) and action_name:
318
+ signature = inspect.signature(attr)
319
+ names[attr_name] = ActionDefinition(
320
+ action_name=action_name,
321
+ module_name=action_module or module.__name__,
322
+ signature=signature,
323
+ )
324
+ return names
325
+
326
+
327
+ def _discover_module_functions(module: Any) -> Set[str]:
328
+ """Discover all async function names defined in a module.
329
+
330
+ This is used to detect when users await functions in the same module
331
+ that are NOT decorated with @action.
332
+ """
333
+ function_names: Set[str] = set()
334
+ for attr_name in dir(module):
335
+ try:
336
+ attr = getattr(module, attr_name)
337
+ except AttributeError:
338
+ continue
339
+
340
+ # Only include functions defined in THIS module (not imported)
341
+ if not callable(attr):
342
+ continue
343
+ if not inspect.iscoroutinefunction(attr):
344
+ continue
345
+
346
+ # Check if the function is defined in this module
347
+ func_module = getattr(attr, "__module__", None)
348
+ if func_module == module.__name__:
349
+ function_names.add(attr_name)
350
+
351
+ return function_names
352
+
353
+
354
@dataclass
class ImportedName:
    """One ``from <module> import <name> [as <alias>]`` binding found in a module."""

    # Name bound in the importing module (the alias when `as` is used), e.g. "sleep".
    local_name: str
    # Module the name was imported from, e.g. "asyncio".
    module: str
    # Name as spelled in the source module, e.g. "sleep".
    original_name: str
361
+
362
+
363
@dataclass
class ModelFieldDefinition:
    """A single field of a Pydantic model or dataclass usable in workflows."""

    # Field name as declared on the model.
    name: str
    # True only when the field declares a literal default value; required
    # fields and default_factory fields are reported as False.
    has_default: bool
    # The literal default; meaningful only when has_default is True.
    default_value: Any = None
370
+
371
+
372
@dataclass
class ModelDefinition:
    """A Pydantic model or dataclass that workflow code may instantiate.

    These are data classes that can be instantiated in workflow code and will
    be converted to dictionary expressions in the IR.
    """

    # Name under which the class is reachable in the workflow's module.
    class_name: str
    # Field definitions keyed by field name.
    fields: Dict[str, ModelFieldDefinition]
    # True for Pydantic models, False for dataclasses.
    is_pydantic: bool
383
+
384
+
385
+ def _is_simple_pydantic_model(cls: type) -> bool:
386
+ """Check if a class is a simple Pydantic model without custom validators.
387
+
388
+ A simple Pydantic model:
389
+ - Inherits from pydantic.BaseModel
390
+ - Has no field_validator or model_validator decorators
391
+ - Has no custom __init__ method
392
+
393
+ Returns False if pydantic is not installed or cls is not a Pydantic model.
394
+ """
395
+ try:
396
+ from pydantic import BaseModel
397
+ except ImportError:
398
+ return False
399
+
400
+ if not isinstance(cls, type) or not issubclass(cls, BaseModel):
401
+ return False
402
+
403
+ # Check for validators - Pydantic v2 uses __pydantic_decorators__
404
+ decorators = getattr(cls, "__pydantic_decorators__", None)
405
+ if decorators is not None:
406
+ # Check for field validators
407
+ if hasattr(decorators, "field_validators") and decorators.field_validators:
408
+ return False
409
+ # Check for model validators
410
+ if hasattr(decorators, "model_validators") and decorators.model_validators:
411
+ return False
412
+
413
+ # Check for custom __init__ (not the one from BaseModel)
414
+ if "__init__" in cls.__dict__:
415
+ return False
416
+
417
+ return True
418
+
419
+
420
+ def _is_simple_dataclass(cls: type) -> bool:
421
+ """Check if a class is a simple dataclass without custom logic.
422
+
423
+ A simple dataclass:
424
+ - Is decorated with @dataclass
425
+ - Has no custom __init__ method (uses the generated one)
426
+ - Has no __post_init__ method
427
+
428
+ Returns False if cls is not a dataclass.
429
+ """
430
+ import dataclasses
431
+
432
+ if not dataclasses.is_dataclass(cls):
433
+ return False
434
+
435
+ # Check for __post_init__ which could have custom logic
436
+ if hasattr(cls, "__post_init__") and "__post_init__" in cls.__dict__:
437
+ return False
438
+
439
+ # Dataclasses generate __init__, so check if there's a custom one
440
+ # that overrides it (unlikely but possible)
441
+ # The dataclass decorator sets __init__ so we can't easily detect override
442
+ # We'll trust that dataclasses without __post_init__ are simple
443
+
444
+ return True
445
+
446
+
447
+ def _get_pydantic_model_fields(cls: type) -> Dict[str, ModelFieldDefinition]:
448
+ """Extract field definitions from a Pydantic model."""
449
+ try:
450
+ from pydantic import BaseModel
451
+ from pydantic.fields import FieldInfo
452
+ except ImportError:
453
+ return {}
454
+
455
+ if not issubclass(cls, BaseModel):
456
+ return {}
457
+
458
+ fields: Dict[str, ModelFieldDefinition] = {}
459
+
460
+ # Pydantic v2 uses model_fields
461
+ model_fields = getattr(cls, "model_fields", {})
462
+ for field_name, field_info in model_fields.items():
463
+ has_default = False
464
+ default_value = None
465
+
466
+ if isinstance(field_info, FieldInfo):
467
+ # Check if field has a default value
468
+ # PydanticUndefined means no default
469
+ from pydantic_core import PydanticUndefined
470
+
471
+ if field_info.default is not PydanticUndefined:
472
+ has_default = True
473
+ default_value = field_info.default
474
+ elif field_info.default_factory is not None:
475
+ # We can't serialize factory functions, so treat as no default
476
+ has_default = False
477
+
478
+ fields[field_name] = ModelFieldDefinition(
479
+ name=field_name,
480
+ has_default=has_default,
481
+ default_value=default_value,
482
+ )
483
+
484
+ return fields
485
+
486
+
487
+ def _get_dataclass_fields(cls: type) -> Dict[str, ModelFieldDefinition]:
488
+ """Extract field definitions from a dataclass."""
489
+ import dataclasses
490
+
491
+ if not dataclasses.is_dataclass(cls):
492
+ return {}
493
+
494
+ fields: Dict[str, ModelFieldDefinition] = {}
495
+ for field in dataclasses.fields(cls):
496
+ has_default = False
497
+ default_value = None
498
+
499
+ if field.default is not dataclasses.MISSING:
500
+ has_default = True
501
+ default_value = field.default
502
+ elif field.default_factory is not dataclasses.MISSING:
503
+ # We can't serialize factory functions, so treat as no default
504
+ has_default = False
505
+
506
+ fields[field.name] = ModelFieldDefinition(
507
+ name=field.name,
508
+ has_default=has_default,
509
+ default_value=default_value,
510
+ )
511
+
512
+ return fields
513
+
514
+
515
+ def _discover_model_definitions(module: Any) -> Dict[str, ModelDefinition]:
516
+ """Discover all Pydantic models and dataclasses that can be used in workflows.
517
+
518
+ Only discovers "simple" models without custom validators or __post_init__.
519
+ """
520
+ models: Dict[str, ModelDefinition] = {}
521
+
522
+ for attr_name in dir(module):
523
+ try:
524
+ attr = getattr(module, attr_name)
525
+ except AttributeError:
526
+ continue
527
+
528
+ if not isinstance(attr, type):
529
+ continue
530
+
531
+ # Check if this class is defined in this module or imported
532
+ # We want to include both, as models might be imported
533
+ if _is_simple_pydantic_model(attr):
534
+ fields = _get_pydantic_model_fields(attr)
535
+ models[attr_name] = ModelDefinition(
536
+ class_name=attr_name,
537
+ fields=fields,
538
+ is_pydantic=True,
539
+ )
540
+ elif _is_simple_dataclass(attr):
541
+ fields = _get_dataclass_fields(attr)
542
+ models[attr_name] = ModelDefinition(
543
+ class_name=attr_name,
544
+ fields=fields,
545
+ is_pydantic=False,
546
+ )
547
+
548
+ return models
549
+
550
+
551
+ def _discover_module_imports(module: Any) -> Dict[str, ImportedName]:
552
+ """Discover imports in a module by parsing its source.
553
+
554
+ Tracks imports like:
555
+ - from asyncio import sleep -> {"sleep": ImportedName("sleep", "asyncio", "sleep")}
556
+ - from asyncio import sleep as s -> {"s": ImportedName("s", "asyncio", "sleep")}
557
+ """
558
+ imported: Dict[str, ImportedName] = {}
559
+
560
+ try:
561
+ source = inspect.getsource(module)
562
+ tree = ast.parse(source)
563
+ except (OSError, TypeError):
564
+ # Can't get source (e.g., built-in module)
565
+ return imported
566
+
567
+ for node in ast.walk(tree):
568
+ if isinstance(node, ast.ImportFrom) and node.module:
569
+ for alias in node.names:
570
+ local_name = alias.asname if alias.asname else alias.name
571
+ imported[local_name] = ImportedName(
572
+ local_name=local_name,
573
+ module=node.module,
574
+ original_name=alias.name,
575
+ )
576
+
577
+ return imported
578
+
579
+
580
+ class IRBuilder(ast.NodeVisitor):
581
+ """Builds IR from Python AST with deep transformations."""
582
+
583
+ def __init__(
584
+ self,
585
+ action_defs: Dict[str, ActionDefinition],
586
+ ctx: TransformContext,
587
+ imported_names: Optional[Dict[str, ImportedName]] = None,
588
+ module_functions: Optional[Set[str]] = None,
589
+ model_defs: Optional[Dict[str, ModelDefinition]] = None,
590
+ ):
591
+ self._action_defs = action_defs
592
+ self._ctx = ctx
593
+ self._imported_names = imported_names or {}
594
+ self._module_functions = module_functions or set()
595
+ self._model_defs = model_defs or {}
596
+ self.function_def: Optional[ir.FunctionDef] = None
597
+ self._statements: List[ir.Statement] = []
598
+
599
def visit_FunctionDef(self, node: ast.FunctionDef) -> Any:
    """Build ``self.function_def`` from the workflow's run method.

    Every named parameter after the receiver (``self``) becomes a workflow
    input, in declaration order: positional-only, then regular, then
    keyword-only parameters. ``*args`` / ``**kwargs`` are not represented.

    Each body statement is lowered with ``_visit_statement``, which may
    yield several IR statements, and the result becomes the function body.
    The IR function declares no outputs.
    """
    # `self` is the first positional parameter. With `def run(self, /, x)` it
    # sits in posonlyargs, so skip it from the combined positional list rather
    # than from node.args.args alone (which would wrongly drop `x`).
    positional = [*node.args.posonlyargs, *node.args.args]
    inputs: List[str] = [arg.arg for arg in positional[1:]]
    # Keyword-only parameters (`def run(self, *, user_id)`) are inputs too.
    inputs.extend(arg.arg for arg in node.args.kwonlyargs)

    self.function_def = ir.FunctionDef(
        name=node.name,
        io=ir.IoDecl(inputs=inputs, outputs=[]),
        span=_make_span(node),
    )

    # _visit_statement returns a list: one Python statement may expand to
    # several IR statements.
    self._statements = []
    for stmt in node.body:
        self._statements.extend(self._visit_statement(stmt))

    self.function_def.body.CopyFrom(ir.Block(statements=self._statements))
621
+
622
def visit_AsyncFunctionDef(self, node: ast.AsyncFunctionDef) -> Any:
    """Build ``self.function_def`` from an ``async def run`` method.

    ``ast.AsyncFunctionDef`` exposes the same ``name`` / ``args`` / ``body``
    fields as ``ast.FunctionDef``, and async functions are lowered exactly
    like sync ones. Delegating to ``visit_FunctionDef`` keeps a single copy
    of that logic, so the two paths cannot drift apart.
    """
    return self.visit_FunctionDef(node)  # type: ignore[arg-type]
641
+
642
def _visit_statement(self, node: ast.stmt) -> List[ir.Statement]:
    """Convert a Python statement to IR Statement(s).

    Returns a list because some transformations (like try block hoisting or
    comprehension expansion) may expand a single Python statement into
    multiple IR statements.

    An annotated assignment (``x: T = value``) is lowered exactly like the
    plain assignment ``x = value``. A bare declaration (``x: T``) binds
    nothing at runtime and produces no IR.

    Raises UnsupportedPatternError for unsupported statement types.
    """
    if isinstance(node, ast.Assign):
        dict_expanded = self._expand_dict_comprehension_assignment(node)
        if dict_expanded is not None:
            return dict_expanded
        expanded = self._expand_list_comprehension_assignment(node)
        if expanded is not None:
            return expanded
        result = self._visit_assign(node)
        return [result] if result else []
    elif isinstance(node, ast.AnnAssign):
        # Previously these fell through and were silently dropped, losing
        # e.g. `result: int = await my_action()` from the workflow.
        if node.value is None:
            return []
        plain_assign = ast.Assign(
            targets=[node.target],
            value=node.value,
            type_comment=None,
        )
        ast.copy_location(plain_assign, node)
        # Re-dispatch so comprehension expansion applies as for `x = ...`.
        return self._visit_statement(plain_assign)
    elif isinstance(node, ast.Expr):
        result = self._visit_expr_stmt(node)
        return [result] if result else []
    elif isinstance(node, ast.For):
        return self._visit_for(node)
    elif isinstance(node, ast.If):
        return self._visit_if(node)
    elif isinstance(node, ast.Try):
        return self._visit_try(node)
    elif isinstance(node, ast.Return):
        return self._visit_return(node)
    elif isinstance(node, ast.AugAssign):
        result = self._visit_aug_assign(node)
        return [result] if result else []
    elif isinstance(node, ast.Pass):
        # Pass statements are fine, they just don't produce IR
        return []

    # Raises for known-unsupported statement types.
    self._check_unsupported_statement(node)

    return []
681
+
682
def _check_unsupported_statement(self, node: ast.stmt) -> None:
    """Raise UnsupportedPatternError if ``node`` is a known-unsupported statement.

    Each rule pairs the offending AST statement type(s) with a short message
    and the RECOMMENDATIONS key holding the rewrite advice. ``ast.Match`` is
    checked only on interpreters that define it. A statement matching no
    rule is left alone and the method returns None.
    """
    rules = [
        (ast.While, "While loops are not supported", "while_loop"),
        (
            (ast.With, ast.AsyncWith),
            "Context managers (with statements) are not supported",
            "with_statement",
        ),
        (ast.Raise, "The 'raise' statement is not supported", "raise_statement"),
        (ast.Assert, "Assert statements are not supported", "assert_statement"),
        (ast.Delete, "The 'del' statement is not supported", "delete"),
        (ast.Global, "Global statements are not supported", "global_statement"),
        (ast.Nonlocal, "Nonlocal statements are not supported", "nonlocal_statement"),
        (
            (ast.Import, ast.ImportFrom),
            "Import statements inside workflow run() are not supported",
            "import_statement",
        ),
        (
            ast.ClassDef,
            "Class definitions inside workflow run() are not supported",
            "class_def",
        ),
        (
            (ast.FunctionDef, ast.AsyncFunctionDef),
            "Nested function definitions are not supported",
            "function_def",
        ),
    ]
    match_stmt = getattr(ast, "Match", None)
    if match_stmt is not None:
        rules.append((match_stmt, "Match statements are not supported", "match"))

    for stmt_types, message, advice_key in rules:
        if isinstance(node, stmt_types):
            raise UnsupportedPatternError(
                message,
                RECOMMENDATIONS[advice_key],
                line=getattr(node, "lineno", None),
                col=getattr(node, "col_offset", None),
            )
764
+
765
def _expand_list_comprehension_assignment(
    self, node: ast.Assign
) -> Optional[List[ir.Statement]]:
    """Expand a list comprehension assignment into loop-based statements.

    Example:
        active_users = [user for user in users if user.active]

    Becomes:
        active_users = []
        for user in users:
            if user.active:
                active_users = active_users + [user]

    If the assigned name also occurs anywhere inside the comprehension
    (e.g. ``items = [x * 2 for x in items]``), Python evaluates the
    comprehension against the *old* value of that name. Resetting the name
    to ``[]`` first would make the loop read the empty accumulator instead.
    In that case the result is built in a fresh temporary, named via
    ``TransformContext.next_implicit_fn_name("listcomp_acc")``, and bound
    to the target name after the loop:

        <tmp> = []
        for x in items:
            <tmp> = <tmp> + [x * 2]
        items = <tmp>

    Returns:
        The IR statements, or None if ``node.value`` is not a ListComp.

    Raises:
        UnsupportedPatternError: for a target other than a single Name,
            more than one ``for`` clause, or an async comprehension.
    """
    if not isinstance(node.value, ast.ListComp):
        return None

    if len(node.targets) != 1 or not isinstance(node.targets[0], ast.Name):
        line = getattr(node, "lineno", None)
        col = getattr(node, "col_offset", None)
        raise UnsupportedPatternError(
            "List comprehension assignments must target a single variable",
            "Assign the comprehension to a simple variable like `results = [x for x in items]`",
            line=line,
            col=col,
        )

    listcomp = node.value
    if len(listcomp.generators) != 1:
        line = getattr(listcomp, "lineno", None)
        col = getattr(listcomp, "col_offset", None)
        raise UnsupportedPatternError(
            "List comprehensions with multiple generators are not supported",
            "Use nested for loops instead of combining multiple generators in one comprehension",
            line=line,
            col=col,
        )

    gen = listcomp.generators[0]
    if gen.is_async:
        line = getattr(listcomp, "lineno", None)
        col = getattr(listcomp, "col_offset", None)
        raise UnsupportedPatternError(
            "Async list comprehensions are not supported",
            "Rewrite using an explicit async for loop",
            line=line,
            col=col,
        )

    target_name = node.targets[0].id

    # Does the comprehension read (or rebind, as a loop variable) the name
    # being assigned? If so, accumulate into a temporary instead.
    self_referential = any(
        isinstance(sub, ast.Name) and sub.id == target_name
        for sub in ast.walk(listcomp)
    )
    acc_name = (
        self._ctx.next_implicit_fn_name("listcomp_acc")
        if self_referential
        else target_name
    )

    # Initialize the accumulator list: <acc> = []
    init_assign_ast = ast.Assign(
        targets=[ast.Name(id=acc_name, ctx=ast.Store())],
        value=ast.List(elts=[], ctx=ast.Load()),
        type_comment=None,
    )
    ast.copy_location(init_assign_ast, node)
    ast.fix_missing_locations(init_assign_ast)

    def _make_append_assignment(value_expr: ast.expr) -> ast.Assign:
        # <acc> = <acc> + [value_expr]
        append_assign = ast.Assign(
            targets=[ast.Name(id=acc_name, ctx=ast.Store())],
            value=ast.BinOp(
                left=ast.Name(id=acc_name, ctx=ast.Load()),
                op=ast.Add(),
                right=ast.List(elts=[copy.deepcopy(value_expr)], ctx=ast.Load()),
            ),
            type_comment=None,
        )
        ast.copy_location(append_assign, node.value)
        ast.fix_missing_locations(append_assign)
        return append_assign

    append_statements: List[ast.stmt] = []
    if isinstance(listcomp.elt, ast.IfExp):
        # `a if cond else b` as the element becomes an if/else of appends.
        then_assign = _make_append_assignment(listcomp.elt.body)
        else_assign = _make_append_assignment(listcomp.elt.orelse)
        branch_if = ast.If(
            test=copy.deepcopy(listcomp.elt.test),
            body=[then_assign],
            orelse=[else_assign],
        )
        ast.copy_location(branch_if, listcomp.elt)
        ast.fix_missing_locations(branch_if)
        append_statements.append(branch_if)
    else:
        append_statements.append(_make_append_assignment(listcomp.elt))

    loop_body: List[ast.stmt] = append_statements
    if gen.ifs:
        # Multiple `if` clauses are combined with `and`.
        condition: ast.expr
        if len(gen.ifs) == 1:
            condition = copy.deepcopy(gen.ifs[0])
        else:
            condition = ast.BoolOp(op=ast.And(), values=[copy.deepcopy(iff) for iff in gen.ifs])
            ast.copy_location(condition, gen.ifs[0])
        if_stmt = ast.If(test=condition, body=append_statements, orelse=[])
        ast.copy_location(if_stmt, gen.ifs[0])
        ast.fix_missing_locations(if_stmt)
        loop_body = [if_stmt]

    loop_ast = ast.For(
        target=copy.deepcopy(gen.target),
        iter=copy.deepcopy(gen.iter),
        body=loop_body,
        orelse=[],
        type_comment=None,
    )
    ast.copy_location(loop_ast, node)
    ast.fix_missing_locations(loop_ast)

    statements: List[ir.Statement] = []
    init_stmt = self._visit_assign(init_assign_ast)
    if init_stmt:
        statements.append(init_stmt)
    statements.extend(self._visit_for(loop_ast))

    if self_referential:
        # Bind the finished list to the real target only after the loop.
        bind_ast = ast.Assign(
            targets=[ast.Name(id=target_name, ctx=ast.Store())],
            value=ast.Name(id=acc_name, ctx=ast.Load()),
            type_comment=None,
        )
        ast.copy_location(bind_ast, node)
        ast.fix_missing_locations(bind_ast)
        bind_stmt = self._visit_assign(bind_ast)
        if bind_stmt:
            statements.append(bind_stmt)

    return statements
884
+
885
+ def _expand_dict_comprehension_assignment(
886
+ self, node: ast.Assign
887
+ ) -> Optional[List[ir.Statement]]:
888
+ """Expand a dict comprehension assignment into loop-based statements."""
889
+ if not isinstance(node.value, ast.DictComp):
890
+ return None
891
+
892
+ if len(node.targets) != 1 or not isinstance(node.targets[0], ast.Name):
893
+ line = getattr(node, "lineno", None)
894
+ col = getattr(node, "col_offset", None)
895
+ raise UnsupportedPatternError(
896
+ "Dict comprehension assignments must target a single variable",
897
+ "Assign the comprehension to a simple variable like `result = {k: v for k, v in pairs}`",
898
+ line=line,
899
+ col=col,
900
+ )
901
+
902
+ dictcomp = node.value
903
+ if len(dictcomp.generators) != 1:
904
+ line = getattr(dictcomp, "lineno", None)
905
+ col = getattr(dictcomp, "col_offset", None)
906
+ raise UnsupportedPatternError(
907
+ "Dict comprehensions with multiple generators are not supported",
908
+ "Use nested for loops instead of combining multiple generators in one comprehension",
909
+ line=line,
910
+ col=col,
911
+ )
912
+
913
+ gen = dictcomp.generators[0]
914
+ if gen.is_async:
915
+ line = getattr(dictcomp, "lineno", None)
916
+ col = getattr(dictcomp, "col_offset", None)
917
+ raise UnsupportedPatternError(
918
+ "Async dict comprehensions are not supported",
919
+ "Rewrite using an explicit async for loop",
920
+ line=line,
921
+ col=col,
922
+ )
923
+
924
+ target_name = node.targets[0].id
925
+
926
+ # Initialize accumulator: result = {}
927
+ init_assign_ast = ast.Assign(
928
+ targets=[ast.Name(id=target_name, ctx=ast.Store())],
929
+ value=ast.Dict(keys=[], values=[]),
930
+ type_comment=None,
931
+ )
932
+ ast.copy_location(init_assign_ast, node)
933
+ ast.fix_missing_locations(init_assign_ast)
934
+
935
+ # result[key] = value
936
+ subscript_target = ast.Subscript(
937
+ value=ast.Name(id=target_name, ctx=ast.Load()),
938
+ slice=copy.deepcopy(dictcomp.key),
939
+ ctx=ast.Store(),
940
+ )
941
+ append_assign_ast = ast.Assign(
942
+ targets=[subscript_target],
943
+ value=copy.deepcopy(dictcomp.value),
944
+ type_comment=None,
945
+ )
946
+ ast.copy_location(append_assign_ast, node.value)
947
+ ast.fix_missing_locations(append_assign_ast)
948
+
949
+ loop_body: List[ast.stmt] = []
950
+ if gen.ifs:
951
+ condition: ast.expr
952
+ if len(gen.ifs) == 1:
953
+ condition = copy.deepcopy(gen.ifs[0])
954
+ else:
955
+ condition = ast.BoolOp(op=ast.And(), values=[copy.deepcopy(iff) for iff in gen.ifs])
956
+ ast.copy_location(condition, gen.ifs[0])
957
+ if_stmt = ast.If(test=condition, body=[append_assign_ast], orelse=[])
958
+ ast.copy_location(if_stmt, gen.ifs[0])
959
+ ast.fix_missing_locations(if_stmt)
960
+ loop_body.append(if_stmt)
961
+ else:
962
+ loop_body.append(append_assign_ast)
963
+
964
+ loop_ast = ast.For(
965
+ target=copy.deepcopy(gen.target),
966
+ iter=copy.deepcopy(gen.iter),
967
+ body=loop_body,
968
+ orelse=[],
969
+ type_comment=None,
970
+ )
971
+ ast.copy_location(loop_ast, node)
972
+ ast.fix_missing_locations(loop_ast)
973
+
974
+ statements: List[ir.Statement] = []
975
+ init_stmt = self._visit_assign(init_assign_ast)
976
+ if init_stmt:
977
+ statements.append(init_stmt)
978
+ statements.extend(self._visit_for(loop_ast))
979
+
980
+ return statements
981
+
982
def _visit_assign(self, node: ast.Assign) -> Optional[ir.Statement]:
    """Lower a Python assignment to an IR ``Assignment`` statement.

    Every supported right-hand side is wrapped in the same ``Assignment``
    message, which gives uniform unpacking support for:
      - Action calls: a, b = @get_pair()
      - Parallel blocks: a, b = parallel: @x() @y()
      - Regular expressions: a, b = some_list

    The right-hand side is tried in this order:
      1. model/dataclass constructor -> dict expression
      2. ``_check_constructor_in_assignment`` validation (may raise)
      3. ``await asyncio.gather(...)`` -> parallel or spread expression
      4. action call
      5. any expression ``_expr_to_ir`` can convert

    Returns ``None`` when none of these produce a value.

    Raises UnsupportedPatternError for:
      - Constructor calls: x = MyClass(...)
      - Non-action await: x = await some_func()
    """
    stmt = ir.Statement(span=_make_span(node))
    targets = self._get_assign_targets(node.targets)
    rhs = node.value

    def emit(value: ir.Expr) -> ir.Statement:
        # All branches share the same Assignment wrapper.
        stmt.assignment.CopyFrom(ir.Assignment(targets=targets, value=value))
        return stmt

    # Pydantic models / dataclasses become dict literals. This must run before
    # the generic constructor check, which would otherwise reject them.
    model_name = self._is_model_constructor(rhs)
    if model_name and isinstance(rhs, ast.Call):
        return emit(self._convert_model_constructor_to_dict(rhs, model_name))

    self._check_constructor_in_assignment(rhs)

    awaited = rhs.value if isinstance(rhs, ast.Await) else None
    if isinstance(awaited, ast.Call) and self._is_asyncio_gather_call(awaited):
        gathered = self._convert_asyncio_gather(awaited)
        if gathered is not None:
            if isinstance(gathered, ir.ParallelExpr):
                return emit(ir.Expr(parallel_expr=gathered, span=_make_span(node)))
            return emit(ir.Expr(spread_expr=gathered, span=_make_span(node)))

    action_call = self._extract_action_call(rhs)
    if action_call:
        return emit(ir.Expr(action_call=action_call, span=_make_span(node)))

    # Plain variables, literals and other convertible expressions.
    converted = _expr_to_ir(rhs)
    if not converted:
        return None
    return emit(converted)
1042
+
1043
def _visit_expr_stmt(self, node: ast.Expr) -> Optional[ir.Statement]:
    """Lower a bare expression statement (result discarded) to IR.

    Handled in priority order:
      * ``await asyncio.gather(...)``: a ``parallel_block`` statement when the
        gather converts to a ParallelExpr, otherwise (spread form) an
        assignment with no targets.
      * an action call: an ``action_call`` statement.
      * ``name.append(value)``: rewritten as ``name = name + [value]`` so the
        mutation is an explicit data dependency in the DAG.
      * anything else ``_expr_to_ir`` converts: an ``expr_stmt``.

    Returns ``None`` when nothing applies.
    """
    stmt = ir.Statement(span=_make_span(node))
    value_node = node.value

    awaited = value_node.value if isinstance(value_node, ast.Await) else None
    if isinstance(awaited, ast.Call) and self._is_asyncio_gather_call(awaited):
        gathered = self._convert_asyncio_gather(awaited)
        if gathered is not None:
            if isinstance(gathered, ir.ParallelExpr):
                block = ir.ParallelBlock()
                block.calls.extend(gathered.calls)
                stmt.parallel_block.CopyFrom(block)
            else:
                # e.g. await asyncio.gather(*[action(x) for x in items])
                spread_value = ir.Expr(spread_expr=gathered, span=_make_span(node))
                stmt.assignment.CopyFrom(ir.Assignment(targets=[], value=spread_value))
            return stmt

    action_call = self._extract_action_call(value_node)
    if action_call:
        stmt.action_call.CopyFrom(action_call)
        return stmt

    is_name_append = (
        isinstance(value_node, ast.Call)
        and isinstance(value_node.func, ast.Attribute)
        and value_node.func.attr == "append"
        and isinstance(value_node.func.value, ast.Name)
        and len(value_node.args) == 1
    )
    if is_name_append:
        target_name = value_node.func.value.id
        current = ir.Expr(variable=ir.Variable(name=target_name), span=_make_span(node))
        item = _expr_to_ir(value_node.args[0])
        if item:
            singleton = ir.Expr(list=ir.ListExpr(elements=[item]), span=_make_span(node))
            concatenated = ir.Expr(
                binary_op=ir.BinaryOp(
                    op=ir.BinaryOperator.BINARY_OP_ADD, left=current, right=singleton
                ),
                span=_make_span(node),
            )
            stmt.assignment.CopyFrom(ir.Assignment(targets=[target_name], value=concatenated))
            return stmt

    converted = _expr_to_ir(value_node)
    if not converted:
        return None
    stmt.expr_stmt.CopyFrom(ir.ExprStmt(expr=converted))
    return stmt
1111
+
1112
def _visit_for(self, node: ast.For) -> List[ir.Statement]:
    """Convert a ``for`` loop to an IR ``ForLoop`` statement.

    The loop body is emitted as a full block so it can contain multiple
    statements/calls and early ``return``.

    Supported loop targets are a single name (``for x in xs``) or a flat
    tuple/list of names (``for k, v in pairs`` / ``for [k, v] in pairs``).

    Returns:
        A one-element list holding the loop statement, or ``[]`` when the
        iterable cannot be converted by ``_expr_to_ir``.

    Raises:
        UnsupportedPatternError: if the loop has an ``else`` clause, or its
            target is anything other than the supported forms (nested
            unpacking, starred names, attributes, subscripts). These cannot
            be represented, and dropping them would skip code or leave names
            unbound inside the body.
    """
    line = getattr(node, "lineno", None)
    col = getattr(node, "col_offset", None)

    # The IR ForLoop has no else block; refuse rather than discard it.
    if node.orelse:
        raise UnsupportedPatternError(
            "for/else loops are not supported",
            "Move the `else` body after the loop, using a flag variable if it must be skipped when the loop exits early",
            line=line,
            col=col,
        )

    target = node.target
    if isinstance(target, ast.Name):
        loop_vars: List[str] = [target.id]
    elif isinstance(target, (ast.Tuple, ast.List)) and all(
        isinstance(elt, ast.Name) for elt in target.elts
    ):
        loop_vars = [elt.id for elt in target.elts]
    else:
        # ForLoop.loop_vars is a flat list of names, so nested/starred/
        # attribute targets cannot be expressed.
        raise UnsupportedPatternError(
            "For loop targets must be a single name or a flat tuple of names",
            "Unpack nested values inside the loop body, e.g. `for i, pair in enumerate(items):` then `a, b = pair`",
            line=getattr(target, "lineno", line),
            col=getattr(target, "col_offset", col),
        )

    iterable = _expr_to_ir(node.iter)
    if not iterable:
        return []

    # Build body statements (recursively transforms nested structures).
    body_stmts: List[ir.Statement] = []
    for body_node in node.body:
        body_stmts.extend(self._visit_statement(body_node))

    stmt = ir.Statement(span=_make_span(node))
    for_loop = ir.ForLoop(
        loop_vars=loop_vars,
        iterable=iterable,
        block_body=ir.Block(statements=body_stmts, span=_make_span(node)),
    )
    stmt.for_loop.CopyFrom(for_loop)
    return [stmt]
1146
+
1147
+ def _detect_accumulator_targets(
1148
+ self, stmts: List[ir.Statement], in_scope_vars: set
1149
+ ) -> List[str]:
1150
+ """Detect out-of-scope variable modifications in for loop body.
1151
+
1152
+ Scans statements for patterns that modify variables defined outside the loop.
1153
+ Returns a list of accumulator variable names that should be set as targets.
1154
+
1155
+ Supported patterns:
1156
+ 1. List append: results.append(value) -> "results"
1157
+ 2. Dict subscript: result[key] = value -> "result"
1158
+ 3. List/set update methods: results.extend(...), results.add(...) -> "results"
1159
+
1160
+ Note: Patterns like `results = results + [x]` and `count = count + 1` create
1161
+ new assignments which are tracked via in_scope_vars and don't need special
1162
+ detection here - they're handled by the regular assignment target logic.
1163
+ """
1164
+ accumulators: List[str] = []
1165
+ seen: set = set()
1166
+
1167
+ for stmt in stmts:
1168
+ var_name = self._extract_accumulator_from_stmt(stmt, in_scope_vars)
1169
+ if var_name and var_name not in seen:
1170
+ accumulators.append(var_name)
1171
+ seen.add(var_name)
1172
+
1173
+ # Check conditionals for accumulator targets in branch bodies
1174
+ if stmt.HasField("conditional"):
1175
+ cond = stmt.conditional
1176
+ branch_blocks: list[ir.Block] = []
1177
+ if cond.HasField("if_branch") and cond.if_branch.HasField("block_body"):
1178
+ branch_blocks.append(cond.if_branch.block_body)
1179
+ for branch in cond.elif_branches:
1180
+ if branch.HasField("block_body"):
1181
+ branch_blocks.append(branch.block_body)
1182
+ if cond.HasField("else_branch") and cond.else_branch.HasField("block_body"):
1183
+ branch_blocks.append(cond.else_branch.block_body)
1184
+
1185
+ for block in branch_blocks:
1186
+ for var in self._detect_accumulator_targets(
1187
+ list(block.statements), in_scope_vars
1188
+ ):
1189
+ if var not in seen:
1190
+ accumulators.append(var)
1191
+ seen.add(var)
1192
+
1193
+ return accumulators
1194
+
1195
+ def _extract_accumulator_from_stmt(
1196
+ self, stmt: ir.Statement, in_scope_vars: set
1197
+ ) -> Optional[str]:
1198
+ """Extract accumulator variable name from a single statement.
1199
+
1200
+ Returns the variable name if this statement modifies an out-of-scope variable,
1201
+ None otherwise.
1202
+ """
1203
+ # Pattern 1: Method calls like list.append(), dict.update(), set.add()
1204
+ if stmt.HasField("expr_stmt"):
1205
+ expr = stmt.expr_stmt.expr
1206
+ if expr.HasField("function_call"):
1207
+ fn_name = expr.function_call.name
1208
+ # Check for mutating method calls: x.append, x.extend, x.add, x.update, etc.
1209
+ mutating_methods = {
1210
+ ".append",
1211
+ ".extend",
1212
+ ".add",
1213
+ ".update",
1214
+ ".insert",
1215
+ ".pop",
1216
+ ".remove",
1217
+ ".clear",
1218
+ }
1219
+ for method in mutating_methods:
1220
+ if fn_name.endswith(method):
1221
+ var_name = fn_name[: len(fn_name) - len(method)]
1222
+ # Only return if it's an out-of-scope variable
1223
+ if var_name and var_name not in in_scope_vars:
1224
+ return var_name
1225
+
1226
+ # Pattern 2: Subscript assignment like dict[key] = value
1227
+ if stmt.HasField("assignment"):
1228
+ for target in stmt.assignment.targets:
1229
+ # Check if target is a subscript pattern (contains '[')
1230
+ if "[" in target:
1231
+ # Extract base variable name (before '[')
1232
+ var_name = target.split("[")[0]
1233
+ if var_name and var_name not in in_scope_vars:
1234
+ return var_name
1235
+
1236
+ # Pattern 3: Self-referential assignment like x = x + [y]
1237
+ # The target variable is used on the RHS, so it must come from outside.
1238
+ # Note: We don't check in_scope_vars here because the assignment itself
1239
+ # would have added the target to in_scope_vars, but it still needs its
1240
+ # previous value from outside the loop body.
1241
+ if stmt.HasField("assignment"):
1242
+ assign = stmt.assignment
1243
+ rhs_vars = self._collect_variables_from_expr(assign.value)
1244
+ for target in assign.targets:
1245
+ if target in rhs_vars:
1246
+ return target
1247
+
1248
+ return None
1249
+
1250
+ def _collect_variables_from_expr(self, expr: ir.Expr) -> set:
1251
+ """Recursively collect all variable names used in an expression."""
1252
+ vars_found: set = set()
1253
+
1254
+ if expr.HasField("variable"):
1255
+ vars_found.add(expr.variable.name)
1256
+ elif expr.HasField("binary_op"):
1257
+ vars_found.update(self._collect_variables_from_expr(expr.binary_op.left))
1258
+ vars_found.update(self._collect_variables_from_expr(expr.binary_op.right))
1259
+ elif expr.HasField("unary_op"):
1260
+ vars_found.update(self._collect_variables_from_expr(expr.unary_op.operand))
1261
+ elif expr.HasField("list"):
1262
+ for elem in expr.list.elements:
1263
+ vars_found.update(self._collect_variables_from_expr(elem))
1264
+ elif expr.HasField("dict"):
1265
+ for key in expr.dict.keys:
1266
+ vars_found.update(self._collect_variables_from_expr(key))
1267
+ for val in expr.dict.values:
1268
+ vars_found.update(self._collect_variables_from_expr(val))
1269
+ elif expr.HasField("index"):
1270
+ vars_found.update(self._collect_variables_from_expr(expr.index.value))
1271
+ vars_found.update(self._collect_variables_from_expr(expr.index.index))
1272
+ elif expr.HasField("dot"):
1273
+ vars_found.update(self._collect_variables_from_expr(expr.dot.object))
1274
+ elif expr.HasField("function_call"):
1275
+ for kwarg in expr.function_call.kwargs:
1276
+ vars_found.update(self._collect_variables_from_expr(kwarg.value))
1277
+ elif expr.HasField("action_call"):
1278
+ for kwarg in expr.action_call.kwargs:
1279
+ vars_found.update(self._collect_variables_from_expr(kwarg.value))
1280
+
1281
+ return vars_found
1282
+
1283
def _visit_if(self, node: ast.If) -> List[ir.Statement]:
    """Convert an if/elif/else chain to IR conditional statement(s).

    A condition that is an awaited action call is hoisted into a temporary
    assignment, so the IR condition is a plain variable. For example:
        if await some_action(...):
            ...
    is normalized into (temp name produced by
    ``self._ctx.next_implicit_fn_name(prefix="if_cond")``):
        __if_cond_n__ = await some_action(...)
        if __if_cond_n__:
            ...

    If only the leading ``if`` needs hoisting, a single flat Conditional with
    elif/else branches is emitted, preceded by the hoisted assignment. If any
    ``elif`` needs hoisting, the chain is rebuilt as nested if/else. Each
    hoisted action call then only runs once every earlier condition is false.

    Returns:
        Statements to splice into the enclosing block: any hoisted
        assignments followed by the conditional. In the flat case, returns
        ``[]`` when the ``if`` condition cannot be converted.

    Raises:
        UnsupportedPatternError: when ``_extract_action_call`` finds an
            action call in a condition that is not itself an ``await``
            expression.
    """

    # Returns (statements to run before the test, IR condition or None).
    def normalize_condition(test: ast.expr) -> tuple[List[ir.Statement], Optional[ir.Expr]]:
        action_call = self._extract_action_call(test)
        if action_call is None:
            return ([], _expr_to_ir(test))

        if not isinstance(test, ast.Await):
            line = getattr(test, "lineno", None)
            col = getattr(test, "col_offset", None)
            raise UnsupportedPatternError(
                "Action calls inside boolean expressions are not supported in if conditions",
                "Assign the awaited action result to a variable, then use the variable in the if condition.",
                line=line,
                col=col,
            )

        # Hoist: <cond_var> = <action call>; condition becomes <cond_var>.
        cond_var = self._ctx.next_implicit_fn_name(prefix="if_cond")
        assign_stmt = ir.Statement(span=_make_span(test))
        assign_stmt.assignment.CopyFrom(
            ir.Assignment(
                targets=[cond_var],
                value=ir.Expr(action_call=action_call, span=_make_span(test)),
            )
        )
        cond_expr = ir.Expr(variable=ir.Variable(name=cond_var), span=_make_span(test))
        return ([assign_stmt], cond_expr)

    # Lower a list of Python statements into IR statements.
    def visit_body(nodes: list[ast.stmt]) -> List[ir.Statement]:
        stmts: List[ir.Statement] = []
        for body_node in nodes:
            stmts.extend(self._visit_statement(body_node))
        return stmts

    # Collect if/elif branches as (test_expr, body_nodes, span_node). Python
    # represents `elif` as an orelse holding exactly one ast.If, so walk that chain.
    branches: list[tuple[ast.expr, list[ast.stmt], ast.AST]] = [(node.test, node.body, node)]
    current = node
    while current.orelse and len(current.orelse) == 1 and isinstance(current.orelse[0], ast.If):
        elif_node = current.orelse[0]
        branches.append((elif_node.test, elif_node.body, elif_node))
        current = elif_node

    # Whatever remains on the last branch is the final `else` body (possibly empty).
    else_nodes = current.orelse

    # (hoisted prefix statements, condition, lowered body, span node) per branch.
    normalized: list[
        tuple[List[ir.Statement], Optional[ir.Expr], List[ir.Statement], ast.AST]
    ] = []
    for test_expr, body_nodes, span_node in branches:
        prefix, cond = normalize_condition(test_expr)
        normalized.append((prefix, cond, visit_body(body_nodes), span_node))

    else_body = visit_body(else_nodes) if else_nodes else []

    # If any non-first branch needs normalization, preserve Python semantics by nesting:
    # an elif's hoisted action call must not run when an earlier condition was true.
    requires_nested = any(prefix for prefix, _, _, _ in normalized[1:])

    # Build a single if/else Conditional statement (no elif branches).
    def build_conditional_stmt(
        condition: ir.Expr,
        then_body: List[ir.Statement],
        else_body_statements: List[ir.Statement],
        span_node: ast.AST,
    ) -> ir.Statement:
        conditional_stmt = ir.Statement(span=_make_span(span_node))
        if_branch = ir.IfBranch(
            condition=condition,
            block_body=ir.Block(statements=then_body, span=_make_span(span_node)),
            span=_make_span(span_node),
        )
        conditional = ir.Conditional(if_branch=if_branch)
        if else_body_statements:
            else_branch = ir.ElseBranch(
                block_body=ir.Block(
                    statements=else_body_statements,
                    span=_make_span(span_node),
                ),
                span=_make_span(span_node),
            )
            conditional.else_branch.CopyFrom(else_branch)
        conditional_stmt.conditional.CopyFrom(conditional)
        return conditional_stmt

    if requires_nested:
        # Build inside-out: each branch's hoisted prefix plus its if/else
        # becomes the else body of the branch before it.
        nested_else: List[ir.Statement] = else_body
        for prefix, cond, then_body, span_node in reversed(normalized):
            # NOTE(review): a branch whose condition _expr_to_ir could not
            # convert is skipped here, dropping its body; the remaining
            # branches are then chained without it. Confirm this cannot
            # happen for validated input.
            if cond is None:
                continue
            nested_if_stmt = build_conditional_stmt(
                condition=cond,
                then_body=then_body,
                else_body_statements=nested_else,
                span_node=span_node,
            )
            nested_else = [*prefix, nested_if_stmt]
        return nested_else

    # Flat conditional with elif/else (original behavior), plus optional prefix for the if guard.
    # Only the first branch can have a prefix here (requires_nested is False).
    if_prefix, if_condition, if_body, if_span_node = normalized[0]
    # NOTE(review): an unconvertible `if` condition drops the whole statement.
    if if_condition is None:
        return []

    conditional_stmt = ir.Statement(span=_make_span(if_span_node))
    if_branch = ir.IfBranch(
        condition=if_condition,
        block_body=ir.Block(statements=if_body, span=_make_span(if_span_node)),
        span=_make_span(if_span_node),
    )
    conditional = ir.Conditional(if_branch=if_branch)

    for _, elif_condition, elif_body, elif_span_node in normalized[1:]:
        # NOTE(review): an unconvertible elif condition silently drops that branch.
        if elif_condition is None:
            continue
        elif_branch = ir.ElifBranch(
            condition=elif_condition,
            block_body=ir.Block(statements=elif_body, span=_make_span(elif_span_node)),
            span=_make_span(elif_span_node),
        )
        conditional.elif_branches.append(elif_branch)

    if else_body:
        # The else branch reuses the span of the leading `if` node.
        else_branch = ir.ElseBranch(
            block_body=ir.Block(statements=else_body, span=_make_span(if_span_node)),
            span=_make_span(if_span_node),
        )
        conditional.else_branch.CopyFrom(else_branch)

    conditional_stmt.conditional.CopyFrom(conditional)
    return [*if_prefix, conditional_stmt]
1420
+
1421
+ def _collect_assigned_vars(self, stmts: List[ir.Statement]) -> set:
1422
+ """Collect all variable names assigned in a list of statements."""
1423
+ assigned = set()
1424
+ for stmt in stmts:
1425
+ if stmt.HasField("assignment"):
1426
+ assigned.update(stmt.assignment.targets)
1427
+ return assigned
1428
+
1429
+ def _collect_assigned_vars_in_order(self, stmts: List[ir.Statement]) -> list[str]:
1430
+ """Collect assigned variable names in statement order (deduplicated)."""
1431
+ assigned: list[str] = []
1432
+ seen: set[str] = set()
1433
+
1434
+ for stmt in stmts:
1435
+ if stmt.HasField("assignment"):
1436
+ for target in stmt.assignment.targets:
1437
+ if target not in seen:
1438
+ seen.add(target)
1439
+ assigned.append(target)
1440
+
1441
+ if stmt.HasField("conditional"):
1442
+ cond = stmt.conditional
1443
+ if cond.HasField("if_branch") and cond.if_branch.HasField("block_body"):
1444
+ for target in self._collect_assigned_vars_in_order(
1445
+ list(cond.if_branch.block_body.statements)
1446
+ ):
1447
+ if target not in seen:
1448
+ seen.add(target)
1449
+ assigned.append(target)
1450
+ for elif_branch in cond.elif_branches:
1451
+ if elif_branch.HasField("block_body"):
1452
+ for target in self._collect_assigned_vars_in_order(
1453
+ list(elif_branch.block_body.statements)
1454
+ ):
1455
+ if target not in seen:
1456
+ seen.add(target)
1457
+ assigned.append(target)
1458
+ if cond.HasField("else_branch") and cond.else_branch.HasField("block_body"):
1459
+ for target in self._collect_assigned_vars_in_order(
1460
+ list(cond.else_branch.block_body.statements)
1461
+ ):
1462
+ if target not in seen:
1463
+ seen.add(target)
1464
+ assigned.append(target)
1465
+
1466
+ if stmt.HasField("for_loop") and stmt.for_loop.HasField("block_body"):
1467
+ for target in self._collect_assigned_vars_in_order(
1468
+ list(stmt.for_loop.block_body.statements)
1469
+ ):
1470
+ if target not in seen:
1471
+ seen.add(target)
1472
+ assigned.append(target)
1473
+
1474
+ if stmt.HasField("try_except"):
1475
+ try_block = stmt.try_except.try_block
1476
+ if try_block.HasField("span"):
1477
+ for target in self._collect_assigned_vars_in_order(list(try_block.statements)):
1478
+ if target not in seen:
1479
+ seen.add(target)
1480
+ assigned.append(target)
1481
+ for handler in stmt.try_except.handlers:
1482
+ if handler.HasField("block_body"):
1483
+ for target in self._collect_assigned_vars_in_order(
1484
+ list(handler.block_body.statements)
1485
+ ):
1486
+ if target not in seen:
1487
+ seen.add(target)
1488
+ assigned.append(target)
1489
+
1490
+ return assigned
1491
+
1492
+ def _collect_variables_from_block(self, block: ir.Block) -> list[str]:
1493
+ return self._collect_variables_from_statements(list(block.statements))
1494
+
1495
+ def _collect_variables_from_statements(self, stmts: List[ir.Statement]) -> list[str]:
1496
+ """Collect variable references from statements in encounter order."""
1497
+ vars_found: list[str] = []
1498
+ seen: set[str] = set()
1499
+
1500
+ for stmt in stmts:
1501
+ if stmt.HasField("assignment") and stmt.assignment.HasField("value"):
1502
+ for var in self._collect_variables_from_expr(stmt.assignment.value):
1503
+ if var not in seen:
1504
+ seen.add(var)
1505
+ vars_found.append(var)
1506
+
1507
+ if stmt.HasField("return_stmt") and stmt.return_stmt.HasField("value"):
1508
+ for var in self._collect_variables_from_expr(stmt.return_stmt.value):
1509
+ if var not in seen:
1510
+ seen.add(var)
1511
+ vars_found.append(var)
1512
+
1513
+ if stmt.HasField("action_call"):
1514
+ expr = ir.Expr(action_call=stmt.action_call, span=stmt.span)
1515
+ for var in self._collect_variables_from_expr(expr):
1516
+ if var not in seen:
1517
+ seen.add(var)
1518
+ vars_found.append(var)
1519
+
1520
+ if stmt.HasField("expr_stmt"):
1521
+ for var in self._collect_variables_from_expr(stmt.expr_stmt.expr):
1522
+ if var not in seen:
1523
+ seen.add(var)
1524
+ vars_found.append(var)
1525
+
1526
+ if stmt.HasField("conditional"):
1527
+ cond = stmt.conditional
1528
+ if cond.HasField("if_branch"):
1529
+ if cond.if_branch.HasField("condition"):
1530
+ for var in self._collect_variables_from_expr(cond.if_branch.condition):
1531
+ if var not in seen:
1532
+ seen.add(var)
1533
+ vars_found.append(var)
1534
+ if cond.if_branch.HasField("block_body"):
1535
+ for var in self._collect_variables_from_block(cond.if_branch.block_body):
1536
+ if var not in seen:
1537
+ seen.add(var)
1538
+ vars_found.append(var)
1539
+ for elif_branch in cond.elif_branches:
1540
+ if elif_branch.HasField("condition"):
1541
+ for var in self._collect_variables_from_expr(elif_branch.condition):
1542
+ if var not in seen:
1543
+ seen.add(var)
1544
+ vars_found.append(var)
1545
+ if elif_branch.HasField("block_body"):
1546
+ for var in self._collect_variables_from_block(elif_branch.block_body):
1547
+ if var not in seen:
1548
+ seen.add(var)
1549
+ vars_found.append(var)
1550
+ if cond.HasField("else_branch") and cond.else_branch.HasField("block_body"):
1551
+ for var in self._collect_variables_from_block(cond.else_branch.block_body):
1552
+ if var not in seen:
1553
+ seen.add(var)
1554
+ vars_found.append(var)
1555
+
1556
+ if stmt.HasField("for_loop"):
1557
+ fl = stmt.for_loop
1558
+ if fl.HasField("iterable"):
1559
+ for var in self._collect_variables_from_expr(fl.iterable):
1560
+ if var not in seen:
1561
+ seen.add(var)
1562
+ vars_found.append(var)
1563
+ if fl.HasField("block_body"):
1564
+ for var in self._collect_variables_from_block(fl.block_body):
1565
+ if var not in seen:
1566
+ seen.add(var)
1567
+ vars_found.append(var)
1568
+
1569
+ if stmt.HasField("try_except"):
1570
+ te = stmt.try_except
1571
+ if te.HasField("try_block"):
1572
+ for var in self._collect_variables_from_block(te.try_block):
1573
+ if var not in seen:
1574
+ seen.add(var)
1575
+ vars_found.append(var)
1576
+ for handler in te.handlers:
1577
+ if handler.HasField("block_body"):
1578
+ for var in self._collect_variables_from_block(handler.block_body):
1579
+ if var not in seen:
1580
+ seen.add(var)
1581
+ vars_found.append(var)
1582
+
1583
+ if stmt.HasField("parallel_block"):
1584
+ for call in stmt.parallel_block.calls:
1585
+ if call.HasField("action"):
1586
+ for kwarg in call.action.kwargs:
1587
+ for var in self._collect_variables_from_expr(kwarg.value):
1588
+ if var not in seen:
1589
+ seen.add(var)
1590
+ vars_found.append(var)
1591
+ elif call.HasField("function"):
1592
+ for kwarg in call.function.kwargs:
1593
+ for var in self._collect_variables_from_expr(kwarg.value):
1594
+ if var not in seen:
1595
+ seen.add(var)
1596
+ vars_found.append(var)
1597
+
1598
+ if stmt.HasField("spread_action"):
1599
+ spread = stmt.spread_action
1600
+ if spread.HasField("collection"):
1601
+ for var in self._collect_variables_from_expr(spread.collection):
1602
+ if var not in seen:
1603
+ seen.add(var)
1604
+ vars_found.append(var)
1605
+ if spread.HasField("action"):
1606
+ for kwarg in spread.action.kwargs:
1607
+ for var in self._collect_variables_from_expr(kwarg.value):
1608
+ if var not in seen:
1609
+ seen.add(var)
1610
+ vars_found.append(var)
1611
+
1612
+ return vars_found
1613
+
1614
def _visit_try(self, node: ast.Try) -> List[ir.Statement]:
    """Convert try/except to IR with full block bodies.

    Each ``except`` clause becomes an ``ExceptHandler`` listing the exception
    class names it matches. A bare ``except:`` yields an empty list.
    Exception types may be plain names (``ValueError``), dotted names
    (``asyncio.TimeoutError``, recorded as ``"TimeoutError"``), or a tuple of
    either.

    Returns:
        A one-element list containing the ``try_except`` statement.

    Raises:
        UnsupportedPatternError: for ``else``/``finally`` clauses (the IR
            TryExcept has no place for them, and dropping them would skip
            code). Also raised for exception type expressions that are not
            names or dotted names; dropping those used to leave an empty type
            list, the same as a bare ``except:``.
    """
    if node.orelse or node.finalbody:
        clause = "finally" if node.finalbody else "else"
        raise UnsupportedPatternError(
            f"try/{clause} clauses are not supported",
            "Move `else` code to the end of the try body, and run `finally` cleanup explicitly after the try/except and inside each handler",
            line=getattr(node, "lineno", None),
            col=getattr(node, "col_offset", None),
        )

    def exception_type_name(type_expr: ast.expr) -> str:
        if isinstance(type_expr, ast.Name):
            return type_expr.id
        if isinstance(type_expr, ast.Attribute):
            # `except asyncio.TimeoutError` -> "TimeoutError". Assumes the
            # runtime matches handlers by bare class name, as it does for
            # plain-name handlers — TODO confirm against the runtime.
            return type_expr.attr
        raise UnsupportedPatternError(
            "Unsupported exception type expression in except clause",
            "Name the exception class directly, e.g. `except ValueError:` or `except (KeyError, mymodule.CustomError):`",
            line=getattr(type_expr, "lineno", None),
            col=getattr(type_expr, "col_offset", None),
        )

    # Build try body statements (recursively transforms nested structures).
    try_body: List[ir.Statement] = []
    for body_node in node.body:
        try_body.extend(self._visit_statement(body_node))

    handlers: List[ir.ExceptHandler] = []
    for handler in node.handlers:
        exception_types: List[str] = []
        if handler.type:
            if isinstance(handler.type, ast.Tuple):
                exception_types = [exception_type_name(elt) for elt in handler.type.elts]
            else:
                exception_types = [exception_type_name(handler.type)]

        # NOTE(review): an `except E as name` binding (handler.name) is not
        # carried into the IR; bodies that read `name` will see it unbound.
        handler_body: List[ir.Statement] = []
        for handler_node in handler.body:
            handler_body.extend(self._visit_statement(handler_node))

        handlers.append(
            ir.ExceptHandler(
                exception_types=exception_types,
                block_body=ir.Block(statements=handler_body, span=_make_span(handler)),
                span=_make_span(handler),
            )
        )

    try_stmt = ir.Statement(span=_make_span(node))
    try_except = ir.TryExcept(
        handlers=handlers,
        try_block=ir.Block(statements=try_body, span=_make_span(node)),
    )
    try_stmt.try_except.CopyFrom(try_except)
    return [try_stmt]
1656
+
1657
+ def _count_calls(self, stmts: List[ir.Statement]) -> int:
1658
+ """Count action calls and function calls in statements.
1659
+
1660
+ Both action calls and function calls (including synthetic functions)
1661
+ count toward the limit of one call per control flow body.
1662
+ """
1663
+ count = 0
1664
+ for stmt in stmts:
1665
+ if stmt.HasField("action_call"):
1666
+ count += 1
1667
+ elif stmt.HasField("assignment"):
1668
+ # Check if assignment value is an action call or function call
1669
+ if stmt.assignment.value.HasField("action_call"):
1670
+ count += 1
1671
+ elif stmt.assignment.value.HasField("function_call"):
1672
+ count += 1
1673
+ elif stmt.HasField("expr_stmt"):
1674
+ # Check if expression is a function call
1675
+ if stmt.expr_stmt.expr.HasField("function_call"):
1676
+ count += 1
1677
+ return count
1678
+
1679
def _wrap_body_as_function(
    self,
    body: List[ir.Statement],
    prefix: str,
    node: ast.AST,
    inputs: Optional[List[str]] = None,
    modified_vars: Optional[List[str]] = None,
) -> List[ir.Statement]:
    """Move *body* into a synthetic function and return a statement calling it.

    The new ``FunctionDef`` is registered in ``self._ctx.implicit_functions``
    under a name from ``self._ctx.next_implicit_fn_name(prefix)``.

    Args:
        body: The statements to wrap.
        prefix: Name prefix for the synthetic function.
        node: AST node used for span information.
        inputs: Variables passed in as keyword arguments (e.g. loop variables).
        modified_vars: Out-of-scope variables modified by the body. They are
            also added to the inputs (if missing), and returned by the
            function. A single variable is returned directly; several are
            returned as an IR list, since tuples are represented as lists in
            the IR.

    Returns:
        A single-statement list. The statement is an assignment of the
        function's result back to ``modified_vars`` when there are any,
        otherwise an expression statement.
    """
    fn_name = self._ctx.next_implicit_fn_name(prefix)
    modified_vars = modified_vars or []

    # Modified variables must be passed in so the function sees their current values.
    fn_inputs: List[str] = list(inputs or [])
    for name in modified_vars:
        if name not in fn_inputs:
            fn_inputs.append(name)

    wrapped_body = list(body)
    if modified_vars:
        if len(modified_vars) == 1:
            returned = ir.Expr(
                variable=ir.Variable(name=modified_vars[0]),
                span=_make_span(node),
            )
        else:
            returned = ir.Expr(
                list=ir.ListExpr(
                    elements=[ir.Expr(variable=ir.Variable(name=name)) for name in modified_vars]
                ),
                span=_make_span(node),
            )
        return_stmt = ir.Statement(span=_make_span(node))
        return_stmt.return_stmt.CopyFrom(ir.ReturnStmt(value=returned))
        wrapped_body.append(return_stmt)

    self._ctx.implicit_functions.append(
        ir.FunctionDef(
            name=fn_name,
            io=ir.IoDecl(inputs=fn_inputs, outputs=modified_vars),
            body=ir.Block(statements=wrapped_body),
            span=_make_span(node),
        )
    )

    call_kwargs = [
        ir.Kwarg(name=name, value=ir.Expr(variable=ir.Variable(name=name))) for name in fn_inputs
    ]
    call_expr = ir.Expr(
        function_call=ir.FunctionCall(name=fn_name, kwargs=call_kwargs),
        span=_make_span(node),
    )

    call_stmt = ir.Statement(span=_make_span(node))
    if not modified_vars:
        call_stmt.expr_stmt.CopyFrom(ir.ExprStmt(expr=call_expr))
        return [call_stmt]

    # var1, var2 = fn(...)  /  var1 = fn(...)
    assignment = ir.Assignment(value=call_expr)
    assignment.targets.extend(modified_vars)
    call_stmt.assignment.CopyFrom(assignment)
    return [call_stmt]
1761
+
1762
def _visit_return(self, node: ast.Return) -> List[ir.Statement]:
    """Lower a ``return`` statement to IR statements.

    Returned values should be plain expressions (variables, literals, ...).
    An awaited action call is split into an assignment followed by a return::

        return await action()

    is emitted as::

        _return_tmp = await action()
        return _return_tmp

    Returns:
        One or two IR statements, each carrying the span of ``node``.

    Raises:
        UnsupportedPatternError: raised by ``_check_constructor_in_return``
            when the value is a constructor call such as
            ``return MyModel(...)``.
    """
    span = _make_span(node)

    def emit_return(value: Optional[ir.Expr] = None) -> ir.Statement:
        # Wrap an optional value in a return statement with this node's span.
        stmt = ir.Statement(span=span)
        if value is None:
            stmt.return_stmt.CopyFrom(ir.ReturnStmt())
        else:
            stmt.return_stmt.CopyFrom(ir.ReturnStmt(value=value))
        return stmt

    if node.value:
        # Reject class instantiation in the returned value up front.
        self._check_constructor_in_return(node.value)

        action_call = self._extract_action_call(node.value)
        if action_call:
            tmp_name = "_return_tmp"
            assign_stmt = ir.Statement(span=span)
            assign_stmt.assignment.CopyFrom(
                ir.Assignment(
                    targets=[tmp_name],
                    value=ir.Expr(action_call=action_call, span=span),
                )
            )
            tmp_ref = ir.Expr(variable=ir.Variable(name=tmp_name), span=span)
            return [assign_stmt, emit_return(tmp_ref)]

        lowered = _expr_to_ir(node.value)
        if lowered:
            return [emit_return(lowered)]

    # NOTE(review): a value that _expr_to_ir cannot lower also reaches this
    # point and is emitted as a bare ``return``. Confirm this is intended.
    return [emit_return()]
+
1812
+ def _visit_aug_assign(self, node: ast.AugAssign) -> Optional[ir.Statement]:
1813
+ """Convert augmented assignment (+=, -=, etc.) to IR."""
1814
+ # For now, we can represent this as a regular assignment with binary op
1815
+ # target op= value -> target = target op value
1816
+ stmt = ir.Statement(span=_make_span(node))
1817
+
1818
+ targets: List[str] = []
1819
+ if isinstance(node.target, ast.Name):
1820
+ targets.append(node.target.id)
1821
+
1822
+ left = _expr_to_ir(node.target)
1823
+ right = _expr_to_ir(node.value)
1824
+ if left and right:
1825
+ op = _bin_op_to_ir(node.op)
1826
+ if op:
1827
+ binary = ir.BinaryOp(left=left, op=op, right=right)
1828
+ value = ir.Expr(binary_op=binary)
1829
+ assign = ir.Assignment(targets=targets, value=value)
1830
+ stmt.assignment.CopyFrom(assign)
1831
+ return stmt
1832
+
1833
+ return None
1834
+
1835
+ def _check_constructor_in_return(self, node: ast.expr) -> None:
1836
+ """Check for constructor calls in return statements.
1837
+
1838
+ Raises UnsupportedPatternError if the return value is a class instantiation
1839
+ like: return MyModel(field=value)
1840
+
1841
+ This is not supported because the workflow IR cannot serialize arbitrary
1842
+ object instantiation. Users should use an @action to create objects.
1843
+ """
1844
+ # Skip if it's an await (action call) - those are fine
1845
+ if isinstance(node, ast.Await):
1846
+ return
1847
+
1848
+ # Check for direct Call that looks like a constructor
1849
+ if isinstance(node, ast.Call):
1850
+ func_name = self._get_constructor_name(node.func)
1851
+ if func_name and self._looks_like_constructor(func_name, node):
1852
+ line = getattr(node, "lineno", None)
1853
+ col = getattr(node, "col_offset", None)
1854
+ raise UnsupportedPatternError(
1855
+ f"Returning constructor call '{func_name}(...)' is not supported",
1856
+ RECOMMENDATIONS["constructor_return"],
1857
+ line=line,
1858
+ col=col,
1859
+ )
1860
+
1861
+ def _check_constructor_in_assignment(self, node: ast.expr) -> None:
1862
+ """Check for constructor calls in assignments.
1863
+
1864
+ Raises UnsupportedPatternError if the assignment value is a class instantiation
1865
+ like: result = MyModel(field=value)
1866
+
1867
+ This is not supported because the workflow IR cannot serialize arbitrary
1868
+ object instantiation. Users should use an @action to create objects.
1869
+ """
1870
+ # Skip if it's an await (action call) - those are fine
1871
+ if isinstance(node, ast.Await):
1872
+ return
1873
+
1874
+ # Check for direct Call that looks like a constructor
1875
+ if isinstance(node, ast.Call):
1876
+ func_name = self._get_constructor_name(node.func)
1877
+ if func_name and self._looks_like_constructor(func_name, node):
1878
+ line = getattr(node, "lineno", None)
1879
+ col = getattr(node, "col_offset", None)
1880
+ raise UnsupportedPatternError(
1881
+ f"Assigning constructor call '{func_name}(...)' is not supported",
1882
+ RECOMMENDATIONS["constructor_assignment"],
1883
+ line=line,
1884
+ col=col,
1885
+ )
1886
+
1887
+ def _get_constructor_name(self, func: ast.expr) -> Optional[str]:
1888
+ """Get the name from a function expression if it looks like a constructor."""
1889
+ if isinstance(func, ast.Name):
1890
+ return func.id
1891
+ elif isinstance(func, ast.Attribute):
1892
+ return func.attr
1893
+ return None
1894
+
1895
+ def _looks_like_constructor(self, func_name: str, call: ast.Call) -> bool:
1896
+ """Check if a function call looks like a class constructor.
1897
+
1898
+ A constructor is identified by:
1899
+ 1. Name starts with uppercase (PEP8 convention for classes)
1900
+ 2. It's not a known action
1901
+ 3. It's not a known builtin like String operations
1902
+ 4. It's not a known Pydantic model or dataclass (those are allowed)
1903
+
1904
+ This is a heuristic - we can't perfectly distinguish constructors
1905
+ from functions without full type information.
1906
+ """
1907
+ # Check if first letter is uppercase (class naming convention)
1908
+ if not func_name or not func_name[0].isupper():
1909
+ return False
1910
+
1911
+ # If it's a known action, it's not a constructor
1912
+ if func_name in self._action_defs:
1913
+ return False
1914
+
1915
+ # If it's a known Pydantic model or dataclass, allow it
1916
+ # (it will be converted to a dict expression)
1917
+ if func_name in self._model_defs:
1918
+ return False
1919
+
1920
+ # Common builtins that start with uppercase but aren't constructors
1921
+ # (these are rarely used in workflow code but let's be safe)
1922
+ builtin_exceptions = {"True", "False", "None", "Ellipsis"}
1923
+ if func_name in builtin_exceptions:
1924
+ return False
1925
+
1926
+ return True
1927
+
1928
+ def _is_model_constructor(self, node: ast.expr) -> Optional[str]:
1929
+ """Check if an expression is a Pydantic model or dataclass constructor call.
1930
+
1931
+ Returns the model name if it is, None otherwise.
1932
+ """
1933
+ if not isinstance(node, ast.Call):
1934
+ return None
1935
+
1936
+ func_name = self._get_constructor_name(node.func)
1937
+ if func_name and func_name in self._model_defs:
1938
+ return func_name
1939
+
1940
+ return None
1941
+
1942
def _convert_model_constructor_to_dict(self, node: ast.Call, model_name: str) -> ir.Expr:
    """Convert a Pydantic model or dataclass constructor call to a dict expression.

    For example::

        MyModel(field1=value1, field2=value2)

    becomes::

        {"field1": value1, "field2": value2}

    Entries are emitted in three groups:

    1. keyword arguments;
    2. positional arguments, mapped to fields in the order of
       ``model_def.fields``;
    3. defaults for fields that were not supplied. Defaults that
       ``_constant_to_literal`` cannot serialize are skipped.

    Field values are lowered with ``_expr_to_ir_with_model_coercion``, so
    nested model constructors such as ``Outer(inner=Inner(x=1))`` also
    become dicts instead of opaque calls.

    Raises:
        UnsupportedPatternError: on ``**kwargs`` expansion, too many
            positional arguments, a field supplied both positionally and by
            keyword (Python itself raises TypeError for this), or a value
            that cannot be lowered to IR.
    """
    model_def = self._model_defs[model_name]
    line = getattr(node, "lineno", None)
    col = getattr(node, "col_offset", None)
    simple_values_hint = "Use simpler expressions (literals, variables, dicts, lists)."

    def string_key(field_name: str) -> ir.Expr:
        # Dict keys are string literals holding the field name.
        key_literal = ir.Literal()
        key_literal.string_value = field_name
        key_expr = ir.Expr()
        key_expr.literal.CopyFrom(key_literal)
        return key_expr

    entries: List[ir.DictEntry] = []
    # Fields already given a value; used for defaults and duplicate detection.
    provided_fields: Set[str] = set()

    # 1. Explicit keyword arguments.
    for kw in node.keywords:
        if kw.arg is None:
            raise UnsupportedPatternError(
                f"Model constructor '{model_name}' with **kwargs is not supported",
                "Use explicit keyword arguments instead of **kwargs.",
                line=line,
                col=col,
            )

        value_expr = self._expr_to_ir_with_model_coercion(kw.value)
        if value_expr is None:
            raise UnsupportedPatternError(
                f"Cannot convert value for field '{kw.arg}' in '{model_name}'",
                simple_values_hint,
                line=line,
                col=col,
            )

        provided_fields.add(kw.arg)
        entries.append(ir.DictEntry(key=string_key(kw.arg), value=value_expr))

    # 2. Positional arguments (dataclasses support these), in field order.
    field_names = list(model_def.fields.keys())
    for index, arg in enumerate(node.args):
        if index >= len(field_names):
            raise UnsupportedPatternError(
                f"Too many positional arguments for '{model_name}'",
                "Use keyword arguments for clarity.",
                line=line,
                col=col,
            )

        field_name = field_names[index]
        if field_name in provided_fields:
            # Without this check the dict would get two entries for the key.
            raise UnsupportedPatternError(
                f"Field '{field_name}' of '{model_name}' is passed both positionally and by keyword",
                "Pass each field only once, preferably as a keyword argument.",
                line=line,
                col=col,
            )

        value_expr = self._expr_to_ir_with_model_coercion(arg)
        if value_expr is None:
            raise UnsupportedPatternError(
                f"Cannot convert positional argument for field '{field_name}' in '{model_name}'",
                simple_values_hint,
                line=line,
                col=col,
            )

        provided_fields.add(field_name)
        entries.append(ir.DictEntry(key=string_key(field_name), value=value_expr))

    # 3. Defaults for fields that were not supplied.
    for field_name, field_def in model_def.fields.items():
        if field_name in provided_fields or not field_def.has_default:
            continue

        default_literal = _constant_to_literal(field_def.default_value)
        if default_literal is None:
            # Not serializable as a literal (e.g. a list factory). Skip it.
            continue

        value_expr = ir.Expr()
        value_expr.literal.CopyFrom(default_literal)
        entries.append(ir.DictEntry(key=string_key(field_name), value=value_expr))

    result = ir.Expr(span=_make_span(node))
    result.dict.CopyFrom(ir.DictExpr(entries=entries))
    return result
+
2055
+ def _check_non_action_await(self, node: ast.Await) -> None:
2056
+ """Check if an await is for a non-action function.
2057
+
2058
+ Note: We can only reliably detect non-action awaits for functions defined
2059
+ in the same module. Actions imported from other modules will pass through
2060
+ and may fail at runtime if they're not actually actions.
2061
+
2062
+ For now, we only check against common builtins and known non-action patterns.
2063
+ A runtime check will catch functions that aren't registered actions.
2064
+ """
2065
+ awaited = node.value
2066
+ if not isinstance(awaited, ast.Call):
2067
+ return
2068
+
2069
+ # Skip special cases that are handled elsewhere
2070
+ if self._is_run_action_call(awaited):
2071
+ return
2072
+ if self._is_asyncio_sleep_call(awaited):
2073
+ return
2074
+ if self._is_asyncio_gather_call(awaited):
2075
+ return
2076
+
2077
+ # Get the function name
2078
+ func_name = None
2079
+ if isinstance(awaited.func, ast.Name):
2080
+ func_name = awaited.func.id
2081
+ elif isinstance(awaited.func, ast.Attribute):
2082
+ func_name = awaited.func.attr
2083
+
2084
+ if not func_name:
2085
+ return
2086
+
2087
+ # Only raise error for functions defined in THIS module that we know
2088
+ # are NOT actions (i.e., async functions without @action decorator)
2089
+ # We can't reliably detect imported non-actions without full type info.
2090
+ #
2091
+ # The check works by looking at _module_functions which contains
2092
+ # functions defined in the same module as the workflow.
2093
+ if func_name in getattr(self, "_module_functions", set()):
2094
+ if func_name not in self._action_defs:
2095
+ line = getattr(node, "lineno", None)
2096
+ col = getattr(node, "col_offset", None)
2097
+ raise UnsupportedPatternError(
2098
+ f"Awaiting non-action function '{func_name}()' is not supported",
2099
+ RECOMMENDATIONS["non_action_call"],
2100
+ line=line,
2101
+ col=col,
2102
+ )
2103
+
2104
+ def _check_sync_function_call(self, node: ast.Call) -> None:
2105
+ """Check for synchronous function calls that should be in actions.
2106
+
2107
+ Common patterns like len(), str(), etc. are not supported in workflow code.
2108
+ """
2109
+ func_name = None
2110
+ if isinstance(node.func, ast.Name):
2111
+ func_name = node.func.id
2112
+ elif isinstance(node.func, ast.Attribute):
2113
+ # Method calls on objects - check the method name
2114
+ func_name = node.func.attr
2115
+
2116
+ if not func_name:
2117
+ return
2118
+
2119
+ # Builtins that users commonly try to use
2120
+ common_builtins = {
2121
+ "len",
2122
+ "str",
2123
+ "int",
2124
+ "float",
2125
+ "bool",
2126
+ "list",
2127
+ "dict",
2128
+ "set",
2129
+ "tuple",
2130
+ "sum",
2131
+ "min",
2132
+ "max",
2133
+ "sorted",
2134
+ "reversed",
2135
+ "enumerate",
2136
+ "zip",
2137
+ "map",
2138
+ "filter",
2139
+ "range",
2140
+ "abs",
2141
+ "round",
2142
+ "print",
2143
+ "type",
2144
+ "isinstance",
2145
+ "hasattr",
2146
+ "getattr",
2147
+ "setattr",
2148
+ "open",
2149
+ "format",
2150
+ }
2151
+
2152
+ if func_name in common_builtins:
2153
+ line = getattr(node, "lineno", None)
2154
+ col = getattr(node, "col_offset", None)
2155
+ raise UnsupportedPatternError(
2156
+ f"Calling built-in function '{func_name}()' directly is not supported",
2157
+ RECOMMENDATIONS["builtin_call"],
2158
+ line=line,
2159
+ col=col,
2160
+ )
2161
+
2162
+ def _extract_action_call(self, node: ast.expr) -> Optional[ir.ActionCall]:
2163
+ """Extract an action call from an expression if present.
2164
+
2165
+ Also validates that awaited calls are actually @action decorated functions.
2166
+ Raises UnsupportedPatternError if awaiting a non-action function.
2167
+ """
2168
+ if not isinstance(node, ast.Await):
2169
+ return None
2170
+
2171
+ awaited = node.value
2172
+ # Handle self.run_action(...) wrapper
2173
+ if isinstance(awaited, ast.Call):
2174
+ if self._is_run_action_call(awaited):
2175
+ # Extract the actual action call from run_action
2176
+ if awaited.args:
2177
+ action_call = self._extract_action_call_from_awaitable(awaited.args[0])
2178
+ if action_call:
2179
+ # Extract policies from run_action kwargs (retry, timeout)
2180
+ self._extract_policies_from_run_action(awaited, action_call)
2181
+ return action_call
2182
+ # Check for asyncio.sleep() - convert to @sleep action
2183
+ if self._is_asyncio_sleep_call(awaited):
2184
+ return self._convert_asyncio_sleep_to_action(awaited)
2185
+ # Try to extract as action call
2186
+ action_call = self._extract_action_call_from_call(awaited)
2187
+ if action_call:
2188
+ return action_call
2189
+
2190
+ # If we get here, it's an await of a non-action function
2191
+ self._check_non_action_await(node)
2192
+ return None
2193
+
2194
+ return None
2195
+
2196
+ def _is_run_action_call(self, node: ast.Call) -> bool:
2197
+ """Check if this is a self.run_action(...) call."""
2198
+ if isinstance(node.func, ast.Attribute):
2199
+ return node.func.attr == "run_action"
2200
+ return False
2201
+
2202
+ def _extract_policies_from_run_action(
2203
+ self, run_action_call: ast.Call, action_call: ir.ActionCall
2204
+ ) -> None:
2205
+ """Extract retry and timeout policies from run_action kwargs.
2206
+
2207
+ Parses patterns like:
2208
+ - self.run_action(action(), retry=RetryPolicy(attempts=3))
2209
+ - self.run_action(action(), timeout=timedelta(seconds=30))
2210
+ - self.run_action(action(), timeout=60)
2211
+ """
2212
+ for kw in run_action_call.keywords:
2213
+ if kw.arg == "retry":
2214
+ retry_policy = self._parse_retry_policy(kw.value)
2215
+ if retry_policy:
2216
+ policy_bracket = ir.PolicyBracket()
2217
+ policy_bracket.retry.CopyFrom(retry_policy)
2218
+ action_call.policies.append(policy_bracket)
2219
+ elif kw.arg == "timeout":
2220
+ timeout_policy = self._parse_timeout_policy(kw.value)
2221
+ if timeout_policy:
2222
+ policy_bracket = ir.PolicyBracket()
2223
+ policy_bracket.timeout.CopyFrom(timeout_policy)
2224
+ action_call.policies.append(policy_bracket)
2225
+
2226
+ def _parse_retry_policy(self, node: ast.expr) -> Optional[ir.RetryPolicy]:
2227
+ """Parse a RetryPolicy(...) call into IR.
2228
+
2229
+ Supports:
2230
+ - RetryPolicy(attempts=3)
2231
+ - RetryPolicy(attempts=3, exception_types=["ValueError"])
2232
+ - RetryPolicy(attempts=3, backoff_seconds=5)
2233
+ """
2234
+ if not isinstance(node, ast.Call):
2235
+ return None
2236
+
2237
+ # Check if it's a RetryPolicy call
2238
+ func_name = None
2239
+ if isinstance(node.func, ast.Name):
2240
+ func_name = node.func.id
2241
+ elif isinstance(node.func, ast.Attribute):
2242
+ func_name = node.func.attr
2243
+
2244
+ if func_name != "RetryPolicy":
2245
+ return None
2246
+
2247
+ policy = ir.RetryPolicy()
2248
+
2249
+ for kw in node.keywords:
2250
+ if kw.arg == "attempts" and isinstance(kw.value, ast.Constant):
2251
+ policy.max_retries = kw.value.value
2252
+ elif kw.arg == "exception_types" and isinstance(kw.value, ast.List):
2253
+ for elt in kw.value.elts:
2254
+ if isinstance(elt, ast.Constant) and isinstance(elt.value, str):
2255
+ policy.exception_types.append(elt.value)
2256
+ elif kw.arg == "backoff_seconds" and isinstance(kw.value, ast.Constant):
2257
+ policy.backoff.seconds = int(kw.value.value)
2258
+
2259
+ return policy
2260
+
2261
+ def _parse_timeout_policy(self, node: ast.expr) -> Optional[ir.TimeoutPolicy]:
2262
+ """Parse a timeout value into IR.
2263
+
2264
+ Supports:
2265
+ - timeout=60 (int seconds)
2266
+ - timeout=30.5 (float seconds)
2267
+ - timeout=timedelta(seconds=30)
2268
+ - timeout=timedelta(minutes=2)
2269
+ """
2270
+ policy = ir.TimeoutPolicy()
2271
+
2272
+ # Direct numeric value (seconds)
2273
+ if isinstance(node, ast.Constant) and isinstance(node.value, (int, float)):
2274
+ policy.timeout.seconds = int(node.value)
2275
+ return policy
2276
+
2277
+ # timedelta(...) call
2278
+ if isinstance(node, ast.Call):
2279
+ func_name = None
2280
+ if isinstance(node.func, ast.Name):
2281
+ func_name = node.func.id
2282
+ elif isinstance(node.func, ast.Attribute):
2283
+ func_name = node.func.attr
2284
+
2285
+ if func_name == "timedelta":
2286
+ total_seconds = 0
2287
+ for kw in node.keywords:
2288
+ if isinstance(kw.value, ast.Constant):
2289
+ val = kw.value.value
2290
+ if kw.arg == "seconds":
2291
+ total_seconds += int(val)
2292
+ elif kw.arg == "minutes":
2293
+ total_seconds += int(val) * 60
2294
+ elif kw.arg == "hours":
2295
+ total_seconds += int(val) * 3600
2296
+ elif kw.arg == "days":
2297
+ total_seconds += int(val) * 86400
2298
+ policy.timeout.seconds = total_seconds
2299
+ return policy
2300
+
2301
+ return None
2302
+
2303
+ def _is_asyncio_sleep_call(self, node: ast.Call) -> bool:
2304
+ """Check if this is an asyncio.sleep(...) call.
2305
+
2306
+ Supports both patterns:
2307
+ - import asyncio; asyncio.sleep(1)
2308
+ - from asyncio import sleep; sleep(1)
2309
+ - from asyncio import sleep as s; s(1)
2310
+ """
2311
+ if isinstance(node.func, ast.Attribute):
2312
+ # asyncio.sleep(...) pattern
2313
+ if node.func.attr == "sleep" and isinstance(node.func.value, ast.Name):
2314
+ return node.func.value.id == "asyncio"
2315
+ elif isinstance(node.func, ast.Name):
2316
+ # sleep(...) pattern - check if it's imported from asyncio
2317
+ func_name = node.func.id
2318
+ if func_name in self._imported_names:
2319
+ imported = self._imported_names[func_name]
2320
+ return imported.module == "asyncio" and imported.original_name == "sleep"
2321
+ return False
2322
+
2323
def _convert_asyncio_sleep_to_action(self, node: ast.Call) -> ir.ActionCall:
    """Convert ``asyncio.sleep(duration)`` to the built-in ``sleep`` action.

    The duration comes from the first positional argument when there is one.
    Otherwise it comes from the first keyword named ``seconds``, ``delay`` or
    ``duration``. It is passed on as the ``duration`` kwarg. If no duration
    can be lowered with ``_expr_to_ir``, the action has no kwargs.
    """
    sleep_call = ir.ActionCall(action_name="sleep")

    if node.args:
        duration_node = node.args[0]
    else:
        # Only the first matching keyword is considered.
        duration_node = next(
            (kw.value for kw in node.keywords if kw.arg in ("seconds", "delay", "duration")),
            None,
        )

    if duration_node is not None:
        duration = _expr_to_ir(duration_node)
        if duration:
            sleep_call.kwargs.append(ir.Kwarg(name="duration", value=duration))

    return sleep_call
+
2348
+ def _is_asyncio_gather_call(self, node: ast.Call) -> bool:
2349
+ """Check if this is an asyncio.gather(...) call.
2350
+
2351
+ Supports both patterns:
2352
+ - import asyncio; asyncio.gather(a(), b())
2353
+ - from asyncio import gather; gather(a(), b())
2354
+ - from asyncio import gather as g; g(a(), b())
2355
+ """
2356
+ if isinstance(node.func, ast.Attribute):
2357
+ # asyncio.gather(...) pattern
2358
+ if node.func.attr == "gather" and isinstance(node.func.value, ast.Name):
2359
+ return node.func.value.id == "asyncio"
2360
+ elif isinstance(node.func, ast.Name):
2361
+ # gather(...) pattern - check if it's imported from asyncio
2362
+ func_name = node.func.id
2363
+ if func_name in self._imported_names:
2364
+ imported = self._imported_names[func_name]
2365
+ return imported.module == "asyncio" and imported.original_name == "gather"
2366
+ return False
2367
+
2368
def _convert_asyncio_gather(
    self, node: ast.Call
) -> Optional[Union[ir.ParallelExpr, ir.SpreadExpr]]:
    """Convert ``asyncio.gather(...)`` to ParallelExpr or SpreadExpr IR.

    Handles two patterns:
    1. Static gather: ``asyncio.gather(a(), b(), c())`` -> ParallelExpr
    2. Spread gather: ``asyncio.gather(*[action(x) for x in items])`` -> SpreadExpr

    Args:
        node: The asyncio.gather() Call node.

    Returns:
        A ParallelExpr or SpreadExpr. Returns None when no positional
        argument converts to a call (including a gather with no arguments).

    Raises:
        UnsupportedPatternError: in any of these cases:
            - a ``*`` spread is combined with other positional arguments;
            - the spread is not a list comprehension;
            - only some of the arguments can be converted.
            Silently dropping an argument changes the number of results and
            misaligns ``a, b = await asyncio.gather(...)``.
    """
    line = getattr(node, "lineno", None)
    col = getattr(node, "col_offset", None)

    has_spread = any(isinstance(arg, ast.Starred) for arg in node.args)
    if has_spread and len(node.args) > 1:
        raise UnsupportedPatternError(
            "Mixing spread (*) and regular arguments in asyncio.gather() is not supported",
            "Gather a single spread on its own: asyncio.gather(*[action(x) for x in items])",
            line=line,
            col=col,
        )

    if has_spread:
        spread_value = node.args[0].value
        # Only list comprehensions are supported for spread.
        if isinstance(spread_value, ast.ListComp):
            return self._convert_listcomp_to_spread_expr(spread_value)
        if isinstance(spread_value, ast.Name):
            raise UnsupportedPatternError(
                f"Spreading variable '{spread_value.id}' in asyncio.gather() is not supported",
                RECOMMENDATIONS["gather_variable_spread"],
                line=line,
                col=col,
            )
        raise UnsupportedPatternError(
            "Spreading non-list-comprehension expressions in asyncio.gather() is not supported",
            RECOMMENDATIONS["gather_variable_spread"],
            line=line,
            col=col,
        )

    # Standard case: gather(a(), b(), c()) -> ParallelExpr
    parallel = ir.ParallelExpr()
    unconverted: List[int] = []
    for position, arg in enumerate(node.args):
        call = self._convert_gather_arg_to_call(arg)
        if call:
            parallel.calls.append(call)
        else:
            unconverted.append(position)

    if not parallel.calls:
        return None

    if unconverted:
        positions = ", ".join(str(p) for p in unconverted)
        raise UnsupportedPatternError(
            f"asyncio.gather() argument(s) at position(s) {positions} could not be converted to calls",
            "Pass only direct @action or function calls to asyncio.gather(), e.g. asyncio.gather(a(), b()).",
            line=line,
            col=col,
        )

    return parallel
+
2424
def _convert_listcomp_to_spread_expr(self, listcomp: ast.ListComp) -> Optional[ir.SpreadExpr]:
    """Convert ``[action(x=item) for item in collection]`` to SpreadExpr IR.

    Requirements, checked in this order:
    1. exactly one ``for`` clause;
    2. no ``if`` filters;
    3. the loop target is a plain name;
    4. the iterable can be lowered by ``_expr_to_ir``;
    5. the element is a call;
    6. that call is a known @action.

    Args:
        listcomp: The ListComp AST node.

    Returns:
        The SpreadExpr. Every failure raises instead of returning None.

    Raises:
        UnsupportedPatternError: on the first requirement that is violated,
            reported at the comprehension's location.
    """
    line = getattr(listcomp, "lineno", None)
    col = getattr(listcomp, "col_offset", None)

    def reject(message: str, recommendation: str) -> UnsupportedPatternError:
        # Build the error at the comprehension's position; the caller raises it.
        return UnsupportedPatternError(message, recommendation, line=line, col=col)

    if len(listcomp.generators) != 1:
        raise reject(
            "Spread pattern only supports a single loop variable",
            "Use a simple list comprehension: [action(x) for x in items]",
        )

    (generator,) = listcomp.generators

    if generator.ifs:
        raise reject(
            "Spread pattern does not support conditions in list comprehension",
            "Remove the 'if' clause from the comprehension",
        )

    if not isinstance(generator.target, ast.Name):
        raise reject(
            "Spread pattern requires a simple loop variable",
            "Use a simple variable: [action(x) for x in items]",
        )

    collection = _expr_to_ir(generator.iter)
    if not collection:
        raise reject(
            "Could not convert collection expression in spread pattern",
            "Ensure the collection is a simple variable or expression",
        )

    element = listcomp.elt
    if not isinstance(element, ast.Call):
        raise reject(
            "Spread pattern requires an action call in the list comprehension",
            "Use: [action(x=item) for item in items]",
        )

    action = self._extract_action_call_from_call(element)
    if not action:
        raise reject(
            "Spread pattern element must be an @action call",
            "Ensure the function is decorated with @action",
        )

    spread = ir.SpreadExpr()
    spread.collection.CopyFrom(collection)
    spread.loop_var = generator.target.id
    spread.action.CopyFrom(action)
    return spread
+
2517
+ def _convert_gather_arg_to_call(self, node: ast.expr) -> Optional[ir.Call]:
2518
+ """Convert a gather argument to an IR Call.
2519
+
2520
+ Handles both action calls and regular function calls.
2521
+ """
2522
+ if not isinstance(node, ast.Call):
2523
+ return None
2524
+
2525
+ # Try to extract as an action call first
2526
+ action_call = self._extract_action_call_from_call(node)
2527
+ if action_call:
2528
+ call = ir.Call()
2529
+ call.action.CopyFrom(action_call)
2530
+ return call
2531
+
2532
+ # Fall back to regular function call
2533
+ func_call = self._convert_to_function_call(node)
2534
+ if func_call:
2535
+ call = ir.Call()
2536
+ call.function.CopyFrom(func_call)
2537
+ return call
2538
+
2539
+ return None
2540
+
2541
+ def _convert_to_function_call(self, node: ast.Call) -> Optional[ir.FunctionCall]:
2542
+ """Convert an AST Call to IR FunctionCall."""
2543
+ func_name = self._get_func_name(node.func)
2544
+ if not func_name:
2545
+ return None
2546
+
2547
+ fn_call = ir.FunctionCall(name=func_name)
2548
+
2549
+ # Add positional args
2550
+ for arg in node.args:
2551
+ expr = _expr_to_ir(arg)
2552
+ if expr:
2553
+ fn_call.args.append(expr)
2554
+
2555
+ # Add keyword args
2556
+ for kw in node.keywords:
2557
+ if kw.arg:
2558
+ expr = _expr_to_ir(kw.value)
2559
+ if expr:
2560
+ fn_call.kwargs.append(ir.Kwarg(name=kw.arg, value=expr))
2561
+
2562
+ return fn_call
2563
+
2564
+ def _get_func_name(self, node: ast.expr) -> Optional[str]:
2565
+ """Get function name from a func node."""
2566
+ if isinstance(node, ast.Name):
2567
+ return node.id
2568
+ elif isinstance(node, ast.Attribute):
2569
+ # Handle chained attributes like obj.method
2570
+ parts = []
2571
+ current = node
2572
+ while isinstance(current, ast.Attribute):
2573
+ parts.append(current.attr)
2574
+ current = current.value
2575
+ if isinstance(current, ast.Name):
2576
+ parts.append(current.id)
2577
+ return ".".join(reversed(parts))
2578
+ return None
2579
+
2580
+ def _expr_to_ir_with_model_coercion(self, node: ast.expr) -> Optional[ir.Expr]:
2581
+ """Convert an AST expression to IR, converting model constructors to dicts.
2582
+
2583
+ This is used for action arguments where Pydantic models or dataclass
2584
+ constructors should be converted to dict expressions that Rust can evaluate.
2585
+
2586
+ If the expression is a model constructor (e.g., MyModel(field=value)),
2587
+ it is converted to a dict expression. Otherwise, falls back to the
2588
+ standard _expr_to_ir conversion.
2589
+ """
2590
+ # Check if this is a model constructor call
2591
+ if isinstance(node, ast.Call):
2592
+ model_name = self._is_model_constructor(node)
2593
+ if model_name:
2594
+ return self._convert_model_constructor_to_dict(node, model_name)
2595
+
2596
+ # Fall back to standard expression conversion
2597
+ return _expr_to_ir(node)
2598
+
2599
def _extract_action_call_from_awaitable(self, node: ast.expr) -> Optional[ir.ActionCall]:
    """Return the ActionCall described by an awaitable expression, if any.

    Only a direct call expression is inspected, via
    ``self._extract_action_call_from_call``. Any other expression shape
    yields None immediately.
    """
    if not isinstance(node, ast.Call):
        return None
    return self._extract_action_call_from_call(node)
2604
+
2605
def _extract_action_call_from_call(self, node: ast.Call) -> Optional[ir.ActionCall]:
    """Extract action call info from a Call node.

    Positional arguments are converted to keyword arguments using the
    action's signature, so every argument is named in the IR. Pydantic
    model and dataclass constructors passed as arguments are converted to
    dict expressions by ``_expr_to_ir_with_model_coercion``.

    Args:
        node: The call expression, e.g. ``my_action(x, y=1)``.

    Returns:
        The populated ``ir.ActionCall``, or None when the callee does not
        resolve to a registered action in ``self._action_defs``.

    Raises:
        UnsupportedPatternError: If the call uses ``*args`` or ``**kwargs``
            unpacking, or passes more positional arguments than the action
            accepts positionally. These arguments cannot be mapped onto
            named kwargs, so rejecting them is safer than emitting an
            action call that silently differs from the source.
    """
    action_name = self._get_action_name(node.func)
    if not action_name or action_name not in self._action_defs:
        return None

    action_def = self._action_defs[action_name]
    action_call = ir.ActionCall(action_name=action_def.action_name)

    # Set the module name so the worker knows where to find the action
    if action_def.module_name:
        action_call.module_name = action_def.module_name

    # Only these parameter kinds can be bound positionally. A Signature
    # always lists them first, so indexing into this list reproduces
    # Python's own positional binding. It never maps an argument onto a
    # *args or keyword-only parameter name.
    positional_kinds = (
        inspect.Parameter.POSITIONAL_ONLY,
        inspect.Parameter.POSITIONAL_OR_KEYWORD,
    )
    positional_names = [
        name
        for name, param in action_def.signature.parameters.items()
        if param.kind in positional_kinds
    ]

    # Convert positional args to kwargs; model constructors become dicts.
    for i, arg in enumerate(node.args):
        if isinstance(arg, ast.Starred):
            raise UnsupportedPatternError(
                f"Star-unpacking (*args) is not supported in calls to action '{action_name}'",
                "Pass each action argument explicitly, preferably by keyword.",
                line=getattr(arg, "lineno", None),
                col=getattr(arg, "col_offset", None),
            )
        if i >= len(positional_names):
            raise UnsupportedPatternError(
                f"Action '{action_name}' accepts at most {len(positional_names)} "
                f"positional argument(s), but {len(node.args)} were given",
                "Pass action arguments by keyword, matching the action's signature.",
                line=getattr(arg, "lineno", None),
                col=getattr(arg, "col_offset", None),
            )
        expr = self._expr_to_ir_with_model_coercion(arg)
        if expr:
            action_call.kwargs.append(ir.Kwarg(name=positional_names[i], value=expr))

    # Add explicit kwargs; model constructors become dicts.
    for kw in node.keywords:
        if kw.arg is None:
            raise UnsupportedPatternError(
                f"Keyword unpacking (**kwargs) is not supported in calls to action '{action_name}'",
                "Pass each action argument explicitly by keyword.",
                line=getattr(kw, "lineno", None),
                col=getattr(kw, "col_offset", None),
            )
        expr = self._expr_to_ir_with_model_coercion(kw.value)
        if expr:
            action_call.kwargs.append(ir.Kwarg(name=kw.arg, value=expr))

    return action_call
2650
+
2651
+ def _get_action_name(self, func: ast.expr) -> Optional[str]:
2652
+ """Get the action name from a function expression."""
2653
+ if isinstance(func, ast.Name):
2654
+ return func.id
2655
+ elif isinstance(func, ast.Attribute):
2656
+ return func.attr
2657
+ return None
2658
+
2659
+ def _get_assign_target(self, targets: List[ast.expr]) -> Optional[str]:
2660
+ """Get the target variable name from assignment targets (single target only)."""
2661
+ if targets and isinstance(targets[0], ast.Name):
2662
+ return targets[0].id
2663
+ return None
2664
+
2665
+ def _get_assign_targets(self, targets: List[ast.expr]) -> List[str]:
2666
+ """Get all target variable names from assignment targets (including tuple unpacking)."""
2667
+ result: List[str] = []
2668
+ for t in targets:
2669
+ if isinstance(t, ast.Name):
2670
+ result.append(t.id)
2671
+ elif isinstance(t, ast.Subscript):
2672
+ formatted = _format_subscript_target(t)
2673
+ if formatted:
2674
+ result.append(formatted)
2675
+ elif isinstance(t, ast.Tuple):
2676
+ for elt in t.elts:
2677
+ if isinstance(elt, ast.Name):
2678
+ result.append(elt.id)
2679
+ return result
2680
+
2681
+
2682
def _make_span(node: ast.AST) -> ir.Span:
    """Build an IR Span from the source position recorded on *node*.

    Position attributes missing from the node become 0. ``end_lineno`` and
    ``end_col_offset`` may also be present but None, so those two are
    additionally coerced to 0.
    """
    start_line = getattr(node, "lineno", 0)
    start_col = getattr(node, "col_offset", 0)
    end_line = getattr(node, "end_lineno", 0) or 0
    end_col = getattr(node, "end_col_offset", 0) or 0
    return ir.Span(
        start_line=start_line,
        start_col=start_col,
        end_line=end_line,
        end_col=end_col,
    )
2690
+
2691
+
2692
def _left_fold_binary(operands: List[ir.Expr], op: ir.BinaryOperator, node: ast.AST) -> ir.Expr:
    """Left-fold ``operands`` with ``op``: ``[a, b, c]`` -> ``(a op b) op c``.

    Each synthesized BinaryOp node carries the span of ``node``, the source
    expression being lowered. A single operand is returned as-is; callers
    must pass at least one operand.
    """
    folded = operands[0]
    for operand in operands[1:]:
        joined = ir.Expr(span=_make_span(node))
        joined.binary_op.CopyFrom(ir.BinaryOp(left=folded, op=op, right=operand))
        folded = joined
    return folded


def _expr_to_ir(expr: ast.AST) -> Optional[ir.Expr]:
    """Convert a Python AST expression to an IR ``Expr``.

    Supported forms:

    * names;
    * constants (None/bool/int/float/str);
    * arithmetic binary operators and unary ``-``/``not``;
    * comparisons, including chains such as ``a < b < c``;
    * ``and``/``or``;
    * list, tuple and dict displays (tuples are lowered as lists);
    * subscripts, attribute access and function calls.

    Returns:
        The IR expression, or None when ``expr`` (or a required
        sub-expression) cannot be represented.

    Raises:
        UnsupportedPatternError: For constructs that
            ``_check_unsupported_expression`` rejects (f-strings, lambdas,
            comprehensions, walrus, yield), whether at the top level or in
            a sub-expression.
    """
    result = ir.Expr(span=_make_span(expr))

    if isinstance(expr, ast.Name):
        result.variable.CopyFrom(ir.Variable(name=expr.id))
        return result

    if isinstance(expr, ast.Constant):
        literal = _constant_to_literal(expr.value)
        if literal is not None:
            result.literal.CopyFrom(literal)
            return result

    if isinstance(expr, ast.BinOp):
        left = _expr_to_ir(expr.left)
        right = _expr_to_ir(expr.right)
        op = _bin_op_to_ir(expr.op)
        # `is not None` rather than truthiness: enum values are ints.
        if left is not None and right is not None and op is not None:
            result.binary_op.CopyFrom(ir.BinaryOp(left=left, op=op, right=right))
            return result

    if isinstance(expr, ast.UnaryOp):
        operand = _expr_to_ir(expr.operand)
        op = _unary_op_to_ir(expr.op)
        if operand is not None and op is not None:
            result.unary_op.CopyFrom(ir.UnaryOp(op=op, operand=operand))
            return result

    if isinstance(expr, ast.Compare):
        left = _expr_to_ir(expr.left)
        if left is None:
            return None
        # Python chains comparisons: `a < b < c` means `a < b and b < c`.
        # Lower every link and AND them together instead of keeping only the
        # first link. NOTE: each middle operand appears in two links, so it
        # is evaluated twice here (Python evaluates it once); this only
        # matters if that operand has side effects.
        links: List[ir.Expr] = []
        for cmp_op, comparator in zip(expr.ops, expr.comparators, strict=True):
            op = _cmp_op_to_ir(cmp_op)
            right = _expr_to_ir(comparator)
            if op is None or right is None:
                # e.g. `is` / `is not` have no IR operator.
                return None
            link = ir.Expr(span=_make_span(expr))
            link.binary_op.CopyFrom(ir.BinaryOp(left=left, op=op, right=right))
            links.append(link)
            left = right
        if not links:
            return None
        return _left_fold_binary(links, ir.BinaryOperator.BINARY_OP_AND, expr)

    if isinstance(expr, ast.BoolOp):
        values = [_expr_to_ir(v) for v in expr.values]
        op = _bool_op_to_ir(expr.op)
        if op is not None and len(values) >= 2 and all(v is not None for v in values):
            # Chain boolean ops: a and b and c -> (a and b) and c
            return _left_fold_binary(values, op, expr)

    if isinstance(expr, (ast.List, ast.Tuple)):
        # Tuples have no dedicated IR form and are lowered as lists.
        elements = [_expr_to_ir(e) for e in expr.elts]
        if all(e is not None for e in elements):
            result.list.CopyFrom(ir.ListExpr(elements=elements))
            return result

    if isinstance(expr, ast.Dict):
        entries: List[ir.DictEntry] = []
        for k, v in zip(expr.keys, expr.values, strict=False):
            # NOTE: `{**other}` spreads (key is None) and entries whose key or
            # value cannot be lowered are skipped rather than rejected.
            if k is None:
                continue
            key_expr = _expr_to_ir(k)
            value_expr = _expr_to_ir(v)
            if key_expr is not None and value_expr is not None:
                entries.append(ir.DictEntry(key=key_expr, value=value_expr))
        result.dict.CopyFrom(ir.DictExpr(entries=entries))
        return result

    if isinstance(expr, ast.Subscript):
        obj = _expr_to_ir(expr.value)
        index = _expr_to_ir(expr.slice) if isinstance(expr.slice, ast.AST) else None
        if obj is not None and index is not None:
            result.index.CopyFrom(ir.IndexAccess(object=obj, index=index))
            return result

    if isinstance(expr, ast.Attribute):
        obj = _expr_to_ir(expr.value)
        if obj is not None:
            result.dot.CopyFrom(ir.DotAccess(object=obj, attribute=expr.attr))
            return result

    if isinstance(expr, ast.Call):
        func_name = _get_func_name(expr.func)
        if func_name:
            args = [_expr_to_ir(a) for a in expr.args]
            kwargs: List[ir.Kwarg] = []
            for kw in expr.keywords:
                if kw.arg:
                    kw_expr = _expr_to_ir(kw.value)
                    if kw_expr is not None:
                        kwargs.append(ir.Kwarg(name=kw.arg, value=kw_expr))
            # NOTE: positional args that cannot be lowered are dropped.
            func_call = ir.FunctionCall(
                name=func_name,
                args=[a for a in args if a is not None],
                kwargs=kwargs,
            )
            result.function_call.CopyFrom(func_call)
            return result

    # Nothing above could lower `expr`: raise for known-unsupported
    # constructs, otherwise report "not representable" with None.
    _check_unsupported_expression(expr)

    return None
2809
+
2810
+
2811
+ def _check_unsupported_expression(expr: ast.AST) -> None:
2812
+ """Check for unsupported expression types and raise descriptive errors."""
2813
+ line = getattr(expr, "lineno", None)
2814
+ col = getattr(expr, "col_offset", None)
2815
+
2816
+ if isinstance(expr, ast.JoinedStr):
2817
+ raise UnsupportedPatternError(
2818
+ "F-strings are not supported",
2819
+ RECOMMENDATIONS["fstring"],
2820
+ line=line,
2821
+ col=col,
2822
+ )
2823
+ elif isinstance(expr, ast.Lambda):
2824
+ raise UnsupportedPatternError(
2825
+ "Lambda expressions are not supported",
2826
+ RECOMMENDATIONS["lambda"],
2827
+ line=line,
2828
+ col=col,
2829
+ )
2830
+ elif isinstance(expr, ast.ListComp):
2831
+ raise UnsupportedPatternError(
2832
+ "List comprehensions are not supported in this context",
2833
+ RECOMMENDATIONS["list_comprehension"],
2834
+ line=line,
2835
+ col=col,
2836
+ )
2837
+ elif isinstance(expr, ast.DictComp):
2838
+ raise UnsupportedPatternError(
2839
+ "Dict comprehensions are not supported in this context",
2840
+ RECOMMENDATIONS["dict_comprehension"],
2841
+ line=line,
2842
+ col=col,
2843
+ )
2844
+ elif isinstance(expr, ast.SetComp):
2845
+ raise UnsupportedPatternError(
2846
+ "Set comprehensions are not supported",
2847
+ RECOMMENDATIONS["set_comprehension"],
2848
+ line=line,
2849
+ col=col,
2850
+ )
2851
+ elif isinstance(expr, ast.GeneratorExp):
2852
+ raise UnsupportedPatternError(
2853
+ "Generator expressions are not supported",
2854
+ RECOMMENDATIONS["generator"],
2855
+ line=line,
2856
+ col=col,
2857
+ )
2858
+ elif isinstance(expr, ast.NamedExpr):
2859
+ raise UnsupportedPatternError(
2860
+ "The walrus operator (:=) is not supported",
2861
+ RECOMMENDATIONS["walrus"],
2862
+ line=line,
2863
+ col=col,
2864
+ )
2865
+ elif isinstance(expr, ast.Yield) or isinstance(expr, ast.YieldFrom):
2866
+ raise UnsupportedPatternError(
2867
+ "Yield expressions are not supported",
2868
+ RECOMMENDATIONS["yield_statement"],
2869
+ line=line,
2870
+ col=col,
2871
+ )
2872
+
2873
+
2874
+ def _format_subscript_target(target: ast.Subscript) -> Optional[str]:
2875
+ """Convert a subscript target to a string representation for tracking."""
2876
+ if not isinstance(target.value, ast.Name):
2877
+ return None
2878
+
2879
+ base = target.value.id
2880
+ try:
2881
+ index_str = ast.unparse(target.slice)
2882
+ except Exception:
2883
+ return None
2884
+
2885
+ return f"{base}[{index_str}]"
2886
+
2887
+
2888
def _constant_to_literal(value: Any) -> Optional[ir.Literal]:
    """Convert a Python constant to an IR Literal.

    Handles None, bool, int, float and str. Any other constant type (bytes,
    complex, Ellipsis, ...) yields None.
    """
    literal = ir.Literal()
    if value is None:
        literal.is_none = True
        return literal
    # bool must be tested before int: bool is a subclass of int.
    if isinstance(value, bool):
        literal.bool_value = value
        return literal
    if isinstance(value, int):
        literal.int_value = value
        return literal
    if isinstance(value, float):
        literal.float_value = value
        return literal
    if isinstance(value, str):
        literal.string_value = value
        return literal
    return None
2904
+
2905
+
2906
+ def _bin_op_to_ir(op: ast.operator) -> Optional[ir.BinaryOperator]:
2907
+ """Convert Python binary operator to IR BinaryOperator."""
2908
+ mapping = {
2909
+ ast.Add: ir.BinaryOperator.BINARY_OP_ADD,
2910
+ ast.Sub: ir.BinaryOperator.BINARY_OP_SUB,
2911
+ ast.Mult: ir.BinaryOperator.BINARY_OP_MUL,
2912
+ ast.Div: ir.BinaryOperator.BINARY_OP_DIV,
2913
+ ast.FloorDiv: ir.BinaryOperator.BINARY_OP_FLOOR_DIV,
2914
+ ast.Mod: ir.BinaryOperator.BINARY_OP_MOD,
2915
+ }
2916
+ return mapping.get(type(op))
2917
+
2918
+
2919
def _unary_op_to_ir(op: ast.unaryop) -> Optional[ir.UnaryOperator]:
    """Map a Python unary operator node to its IR UnaryOperator.

    Only negation (``-x``) and logical not (``not x``) are supported.
    """
    op_type = type(op)
    if op_type is ast.USub:
        return ir.UnaryOperator.UNARY_OP_NEG
    if op_type is ast.Not:
        return ir.UnaryOperator.UNARY_OP_NOT
    # Unary plus (``+x``) and bitwise invert (``~x``) have no IR equivalent.
    return None
2926
+
2927
+
2928
def _cmp_op_to_ir(op: ast.cmpop) -> Optional[ir.BinaryOperator]:
    """Map a Python comparison operator node to its IR BinaryOperator.

    Identity comparisons (``is`` / ``is not``) have no IR counterpart and
    return None.
    """
    op_type = type(op)
    pairs = (
        (ast.Eq, ir.BinaryOperator.BINARY_OP_EQ),
        (ast.NotEq, ir.BinaryOperator.BINARY_OP_NE),
        (ast.Lt, ir.BinaryOperator.BINARY_OP_LT),
        (ast.LtE, ir.BinaryOperator.BINARY_OP_LE),
        (ast.Gt, ir.BinaryOperator.BINARY_OP_GT),
        (ast.GtE, ir.BinaryOperator.BINARY_OP_GE),
        (ast.In, ir.BinaryOperator.BINARY_OP_IN),
        (ast.NotIn, ir.BinaryOperator.BINARY_OP_NOT_IN),
    )
    for py_op, ir_op in pairs:
        if op_type is py_op:
            return ir_op
    return None
2941
+
2942
+
2943
+ def _bool_op_to_ir(op: ast.boolop) -> Optional[ir.BinaryOperator]:
2944
+ """Convert Python boolean operator to IR BinaryOperator."""
2945
+ mapping = {
2946
+ ast.And: ir.BinaryOperator.BINARY_OP_AND,
2947
+ ast.Or: ir.BinaryOperator.BINARY_OP_OR,
2948
+ }
2949
+ return mapping.get(type(op))
2950
+
2951
+
2952
+ def _get_func_name(func: ast.expr) -> Optional[str]:
2953
+ """Get function name from a Call's func attribute."""
2954
+ if isinstance(func, ast.Name):
2955
+ return func.id
2956
+ elif isinstance(func, ast.Attribute):
2957
+ # For method calls like obj.method, return full dotted name
2958
+ parts = []
2959
+ current = func
2960
+ while isinstance(current, ast.Attribute):
2961
+ parts.append(current.attr)
2962
+ current = current.value
2963
+ if isinstance(current, ast.Name):
2964
+ parts.append(current.id)
2965
+ return ".".join(reversed(parts))
2966
+ return None