rappel 0.4.1__py3-none-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of rappel might be problematic. Click here for more details.

rappel/ir_builder.py ADDED
@@ -0,0 +1,3146 @@
1
+ """
2
+ IR Builder - Converts Python workflow AST to Rappel IR (ast.proto).
3
+
4
+ This module parses Python workflow classes and produces the IR representation
5
+ that can be sent to the Rust runtime for execution.
6
+
7
+ The IR builder performs deep transformations to convert Python patterns into
8
+ valid Rappel IR structures. Each control flow body (try, for, if branches)
9
+ should have at most ONE action/function call. When bodies have multiple
10
+ action calls, they are wrapped into synthetic functions.
11
+
12
+ Transformations:
13
+ 1. **Try body wrapping**: Wraps multi-action try bodies into synthetic functions
14
+ 2. **For loop body wrapping**: Wraps multi-action for bodies into synthetic functions
15
+ 3. **If branch wrapping**: Wraps multi-action if/elif/else branches into synthetic functions
16
+ 4. **Exception handler wrapping**: Wraps multi-action handlers into synthetic functions
17
+
18
+ Validation:
19
+ The IR builder proactively detects unsupported Python patterns and raises
20
+ UnsupportedPatternError with clear recommendations for how to rewrite the code.
21
+ """
22
+
23
from __future__ import annotations

import ast
import copy
import inspect
import textwrap
from dataclasses import dataclass, field
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Set, Union

from proto import ast_pb2 as ir
33
+
34
+
35
class UnsupportedPatternError(Exception):
    """Signals Python syntax that the IR builder cannot translate.

    The rendered exception text is ``"<message>[ (line N)]\\n\\nRecommendation:
    <recommendation>"``, so simply printing the exception tells the user both
    what went wrong and how to rewrite it.

    Attributes:
        message: Short description of the unsupported pattern.
        recommendation: Suggested rewrite using supported constructs.
        line: Source line of the offending node, or None when unknown.
        col: Column offset of the offending node, or None when unknown.
            Stored for callers; it is not included in the rendered text.
    """

    def __init__(
        self,
        message: str,
        recommendation: str,
        line: Optional[int] = None,
        col: Optional[int] = None,
    ):
        self.message, self.recommendation = message, recommendation
        self.line, self.col = line, col

        # A falsy line (None or 0) means "location unknown": omit the suffix.
        where = "" if not line else f" (line {line})"
        super().__init__(f"{message}{where}\n\nRecommendation: {recommendation}")
57
+
58
+
59
# Shared wording for the two "constructor" recommendations below.
_CONSTRUCTOR_ACTION_INTRO = (
    "The workflow IR cannot serialize arbitrary object instantiation.\n\n"
    "Use an @action to create the object:\n\n"
)

# Shared "instead of / use" example for both asyncio.gather() spread patterns.
_GATHER_COMPREHENSION_EXAMPLE = (
    "Use a list comprehension directly in gather:\n\n"
    " # Instead of:\n"
    " tasks = []\n"
    " for i in range(count):\n"
    " tasks.append(process(value=i))\n"
    " results = await asyncio.gather(*tasks)\n\n"
    " # Use:\n"
    " results = await asyncio.gather(*[process(value=i) for i in range(count)])"
)

# Recommendations for common unsupported patterns, keyed by pattern id.
# Each value is user-facing text passed to UnsupportedPatternError.
RECOMMENDATIONS = {
    # --- object construction -------------------------------------------------
    "constructor_return": (
        "Returning a class constructor (like MyModel(...)) directly is not supported.\n"
        + _CONSTRUCTOR_ACTION_INTRO
        + " @action\n"
        " async def build_result(items: list, count: int) -> MyResult:\n"
        " return MyResult(items=items, count=count)\n\n"
        " # In workflow:\n"
        " return await build_result(items, count)"
    ),
    "constructor_assignment": (
        "Assigning a class constructor result (like x = MyClass(...)) is not supported.\n"
        + _CONSTRUCTOR_ACTION_INTRO
        + " @action\n"
        " async def create_config(value: int) -> Config:\n"
        " return Config(value=value)\n\n"
        " # In workflow:\n"
        " config = await create_config(value)"
    ),
    # --- calls -----------------------------------------------------------------
    "non_action_call": (
        "Calling a function that is not decorated with @action is not supported.\n"
        "Only @action decorated functions can be awaited in workflow code.\n\n"
        "Add the @action decorator to your function:\n\n"
        " @action\n"
        " async def my_function(x: int) -> int:\n"
        " return x * 2"
    ),
    "sync_function_call": (
        "Calling a synchronous function directly in workflow code is not supported.\n"
        "All computation must happen inside @action decorated async functions.\n\n"
        "Wrap your logic in an @action:\n\n"
        " @action\n"
        " async def compute(x: int) -> int:\n"
        " return some_sync_function(x)"
    ),
    "method_call_non_self": (
        "Calling methods on objects other than 'self' is not supported in workflow code.\n"
        "Use an @action to perform method calls:\n\n"
        " @action\n"
        " async def call_method(obj: MyClass) -> Result:\n"
        " return obj.some_method()"
    ),
    "builtin_call": (
        "Calling built-in functions like len(), str(), int() directly is not supported.\n"
        "Use an @action to perform these operations:\n\n"
        " @action\n"
        " async def get_length(items: list) -> int:\n"
        " return len(items)"
    ),
    # --- expressions and statements ----------------------------------------------
    "fstring": (
        "F-strings are not supported in workflow code because they require "
        "runtime string interpolation.\n"
        "Use an @action to perform string formatting:\n\n"
        " @action\n"
        " async def format_message(value: int) -> str:\n"
        " return f'Result: {value}'"
    ),
    "delete": (
        "The 'del' statement is not supported in workflow code.\n"
        "Use an @action to perform mutations:\n\n"
        " @action\n"
        " async def remove_key(data: dict, key: str) -> dict:\n"
        " del data[key]\n"
        " return data"
    ),
    "while_loop": (
        "While loops are not supported in workflow code because they can run "
        "indefinitely.\n"
        "Use a for loop with a fixed range, or restructure as recursive workflow calls."
    ),
    "with_statement": (
        "Context managers (with statements) are not supported in workflow code.\n"
        "Use an @action to handle resource management:\n\n"
        " @action\n"
        " async def read_file(path: str) -> str:\n"
        " with open(path) as f:\n"
        " return f.read()"
    ),
    "raise_statement": (
        "The 'raise' statement is not supported directly in workflow code.\n"
        "Use an @action that raises exceptions, or return error values."
    ),
    "assert_statement": (
        "Assert statements are not supported in workflow code.\n"
        "Use an @action for validation, or use if statements with explicit error handling."
    ),
    "lambda": (
        "Lambda expressions are not supported in workflow code.\n"
        "Use an @action to define the function logic."
    ),
    # --- comprehensions and generators -------------------------------------------
    "list_comprehension": (
        "List comprehensions are only supported when assigned directly to a variable "
        "or inside asyncio.gather(*[...]).\n"
        "For other cases, use a for loop or an @action."
    ),
    "dict_comprehension": (
        "Dict comprehensions are only supported when assigned directly to a variable.\n"
        "For other cases, use a for loop or an @action:\n\n"
        " @action\n"
        " async def build_dict(items: list) -> dict:\n"
        " return {k: v for k, v in items}"
    ),
    "set_comprehension": (
        "Set comprehensions are not supported in workflow code.\n"
        "Use an @action to build sets."
    ),
    "generator": (
        "Generator expressions are not supported in workflow code.\n"
        "Use a list or an @action instead."
    ),
    "walrus": (
        "The walrus operator (:=) is not supported in workflow code.\n"
        "Use separate assignment statements instead."
    ),
    "match": (
        "Match statements are not supported in workflow code.\n"
        "Use if/elif/else chains instead."
    ),
    # --- asyncio.gather ------------------------------------------------------------
    "gather_variable_spread": (
        "Spreading a variable in asyncio.gather() is not supported because it requires "
        "data flow analysis to determine the contents.\n" + _GATHER_COMPREHENSION_EXAMPLE
    ),
    "for_loop_append_pattern": (
        "Building a task list in a for loop then spreading in asyncio.gather() is not "
        "supported.\n" + _GATHER_COMPREHENSION_EXAMPLE
    ),
    # --- scope and definitions inside run() -------------------------------------------
    "global_statement": (
        "Global statements are not supported in workflow code.\n"
        "Use workflow state or pass values explicitly."
    ),
    "nonlocal_statement": (
        "Nonlocal statements are not supported in workflow code.\n"
        "Use explicit parameter passing instead."
    ),
    "import_statement": (
        "Import statements inside workflow run() are not supported.\n"
        "Place imports at the module level."
    ),
    "class_def": (
        "Class definitions inside workflow run() are not supported.\n"
        "Define classes at the module level."
    ),
    "function_def": (
        "Nested function definitions inside workflow run() are not supported.\n"
        "Define functions at the module level or use @action."
    ),
    "yield_statement": (
        "Yield statements are not supported in workflow code.\n"
        "Workflows must return a complete result, not generate values incrementally."
    ),
}
227
+
228
+
229
+ if TYPE_CHECKING:
230
+ from .workflow import Workflow
231
+
232
+
233
@dataclass
class ActionDefinition:
    """Metadata for one ``@action``-decorated callable found in a workflow module."""

    # Registered action name (read from the callable's ``__rappel_action_name__``).
    action_name: str
    # Module the action is registered under; discovery fills this from
    # ``__rappel_action_module__`` or falls back to the scanned module's name.
    module_name: Optional[str]
    # ``inspect.signature`` of the decorated callable.
    signature: inspect.Signature
240
+
241
+
242
@dataclass
class TransformContext:
    """Mutable state shared across a single IR build.

    Holds the counter used to name synthetic ("implicit") functions and the
    list those functions are collected into. ``build_workflow_ir`` adds the
    collected functions to the Program ahead of the main function.
    """

    # Counter for generating unique function names. It is incremented before
    # use, so the first generated name ends in ``_1__``.
    implicit_fn_counter: int = 0
    # Implicit functions generated during transformation. default_factory gives
    # every instance its own list (never a shared mutable default).
    implicit_functions: List[ir.FunctionDef] = field(default_factory=list)

    def __post_init__(self) -> None:
        # Backward compatibility: an explicit ``implicit_functions=None`` still
        # means "start with an empty list".
        if self.implicit_functions is None:
            self.implicit_functions = []

    def next_implicit_fn_name(self, prefix: str = "implicit") -> str:
        """Return a unique name of the form ``__{prefix}_{n}__``.

        ``n`` increases by one on every call, regardless of ``prefix``.
        """
        self.implicit_fn_counter += 1
        return f"__{prefix}_{self.implicit_fn_counter}__"
259
+
260
+
261
def build_workflow_ir(workflow_cls: type["Workflow"]) -> ir.Program:
    """Build an IR Program from a workflow class.

    The workflow's run implementation is taken from ``__workflow_run_impl__``
    when present, otherwise from ``run`` in the class's own ``__dict__``. Its
    source is parsed, and the module it lives in is scanned for @action
    functions, ``from`` imports, async functions, and simple Pydantic models /
    dataclasses, which the builder uses to resolve names.

    Args:
        workflow_cls: The workflow class to convert.

    Returns:
        An IR Program proto message. Synthetic functions generated by the
        transformations come first (the main function may call them),
        followed by the main run function.

    Raises:
        ValueError: If the class has no run implementation or its module
            cannot be determined.
        OSError: If the source of the run implementation is unavailable
            (propagated from ``inspect.getsource``).
        UnsupportedPatternError: If run() uses an unsupported Python pattern.
    """
    original_run = getattr(workflow_cls, "__workflow_run_impl__", None)
    if original_run is None:
        original_run = workflow_cls.__dict__.get("run")
    if original_run is None:
        raise ValueError(f"workflow {workflow_cls!r} missing run implementation")

    module = inspect.getmodule(original_run)
    if module is None:
        raise ValueError(f"unable to locate module for workflow {workflow_cls!r}")

    # Methods are indented inside their class; dedent so ast.parse accepts them.
    tree = ast.parse(textwrap.dedent(inspect.getsource(original_run)))

    # Module-level discovery, in the same order the builder's arguments expect:
    # actions, imports (for built-in detection, e.g. `from asyncio import sleep`),
    # async functions (for non-action detection), and usable model classes.
    ctx = TransformContext()
    builder = IRBuilder(
        _discover_action_names(module),
        ctx,
        _discover_module_imports(module),
        _discover_module_functions(module),
        _discover_model_definitions(module),
    )
    builder.visit(tree)

    program = ir.Program()
    # Implicit functions first: the main function may call them.
    program.functions.extend(ctx.implicit_functions)
    # Explicit None check rather than relying on protobuf message truthiness.
    if builder.function_def is not None:
        program.functions.append(builder.function_def)

    return program
313
+
314
+
315
+ def _discover_action_names(module: Any) -> Dict[str, ActionDefinition]:
316
+ """Discover all @action decorated functions in a module."""
317
+ names: Dict[str, ActionDefinition] = {}
318
+ for attr_name in dir(module):
319
+ attr = getattr(module, attr_name)
320
+ action_name = getattr(attr, "__rappel_action_name__", None)
321
+ action_module = getattr(attr, "__rappel_action_module__", None)
322
+ if callable(attr) and action_name:
323
+ signature = inspect.signature(attr)
324
+ names[attr_name] = ActionDefinition(
325
+ action_name=action_name,
326
+ module_name=action_module or module.__name__,
327
+ signature=signature,
328
+ )
329
+ return names
330
+
331
+
332
+ def _discover_module_functions(module: Any) -> Set[str]:
333
+ """Discover all async function names defined in a module.
334
+
335
+ This is used to detect when users await functions in the same module
336
+ that are NOT decorated with @action.
337
+ """
338
+ function_names: Set[str] = set()
339
+ for attr_name in dir(module):
340
+ try:
341
+ attr = getattr(module, attr_name)
342
+ except AttributeError:
343
+ continue
344
+
345
+ # Only include functions defined in THIS module (not imported)
346
+ if not callable(attr):
347
+ continue
348
+ if not inspect.iscoroutinefunction(attr):
349
+ continue
350
+
351
+ # Check if the function is defined in this module
352
+ func_module = getattr(attr, "__module__", None)
353
+ if func_module == module.__name__:
354
+ function_names.add(attr_name)
355
+
356
+ return function_names
357
+
358
+
359
@dataclass
class ImportedName:
    """One ``from <module> import <name> [as <alias>]`` binding found in a module.

    Example: ``from asyncio import sleep as s`` is recorded as
    ``ImportedName(local_name="s", module="asyncio", original_name="sleep")``.
    """

    # Name bound in the importing module (the alias when ``as`` is used).
    local_name: str
    # Module named in the ``from`` clause.
    module: str
    # Name as defined in the source module, before any aliasing.
    original_name: str
366
+
367
+
368
@dataclass
class ModelFieldDefinition:
    """A single field of a Pydantic model or dataclass usable in workflows."""

    # Field name as declared on the model.
    name: str
    # Whether the field has a default value. The discovery helpers in this
    # module report factory-based defaults as "no default".
    has_default: bool
    # The default value; only meaningful when ``has_default`` is True.
    default_value: Any = None
375
+
376
+
377
@dataclass
class ModelDefinition:
    """A Pydantic model or dataclass that workflow code may instantiate.

    Instantiations of these classes are converted to dictionary expressions
    in the IR.
    """

    # Class name as it appears in the scanned module's namespace.
    class_name: str
    # Field definitions keyed by field name.
    fields: Dict[str, ModelFieldDefinition]
    # True for Pydantic models, False for dataclasses.
    is_pydantic: bool
389
+
390
+ def _is_simple_pydantic_model(cls: type) -> bool:
391
+ """Check if a class is a simple Pydantic model without custom validators.
392
+
393
+ A simple Pydantic model:
394
+ - Inherits from pydantic.BaseModel
395
+ - Has no field_validator or model_validator decorators
396
+ - Has no custom __init__ method
397
+
398
+ Returns False if pydantic is not installed or cls is not a Pydantic model.
399
+ """
400
+ try:
401
+ from pydantic import BaseModel
402
+ except ImportError:
403
+ return False
404
+
405
+ if not isinstance(cls, type) or not issubclass(cls, BaseModel):
406
+ return False
407
+
408
+ # Check for validators - Pydantic v2 uses __pydantic_decorators__
409
+ decorators = getattr(cls, "__pydantic_decorators__", None)
410
+ if decorators is not None:
411
+ # Check for field validators
412
+ if hasattr(decorators, "field_validators") and decorators.field_validators:
413
+ return False
414
+ # Check for model validators
415
+ if hasattr(decorators, "model_validators") and decorators.model_validators:
416
+ return False
417
+
418
+ # Check for custom __init__ (not the one from BaseModel)
419
+ if "__init__" in cls.__dict__:
420
+ return False
421
+
422
+ return True
423
+
424
+
425
+ def _is_simple_dataclass(cls: type) -> bool:
426
+ """Check if a class is a simple dataclass without custom logic.
427
+
428
+ A simple dataclass:
429
+ - Is decorated with @dataclass
430
+ - Has no custom __init__ method (uses the generated one)
431
+ - Has no __post_init__ method
432
+
433
+ Returns False if cls is not a dataclass.
434
+ """
435
+ import dataclasses
436
+
437
+ if not dataclasses.is_dataclass(cls):
438
+ return False
439
+
440
+ # Check for __post_init__ which could have custom logic
441
+ if hasattr(cls, "__post_init__") and "__post_init__" in cls.__dict__:
442
+ return False
443
+
444
+ # Dataclasses generate __init__, so check if there's a custom one
445
+ # that overrides it (unlikely but possible)
446
+ # The dataclass decorator sets __init__ so we can't easily detect override
447
+ # We'll trust that dataclasses without __post_init__ are simple
448
+
449
+ return True
450
+
451
+
452
+ def _get_pydantic_model_fields(cls: type) -> Dict[str, ModelFieldDefinition]:
453
+ """Extract field definitions from a Pydantic model."""
454
+ try:
455
+ from pydantic import BaseModel
456
+ from pydantic.fields import FieldInfo
457
+ except ImportError:
458
+ return {}
459
+
460
+ if not issubclass(cls, BaseModel):
461
+ return {}
462
+
463
+ fields: Dict[str, ModelFieldDefinition] = {}
464
+
465
+ # Pydantic v2 uses model_fields
466
+ model_fields = getattr(cls, "model_fields", {})
467
+ for field_name, field_info in model_fields.items():
468
+ has_default = False
469
+ default_value = None
470
+
471
+ if isinstance(field_info, FieldInfo):
472
+ # Check if field has a default value
473
+ # PydanticUndefined means no default
474
+ from pydantic_core import PydanticUndefined
475
+
476
+ if field_info.default is not PydanticUndefined:
477
+ has_default = True
478
+ default_value = field_info.default
479
+ elif field_info.default_factory is not None:
480
+ # We can't serialize factory functions, so treat as no default
481
+ has_default = False
482
+
483
+ fields[field_name] = ModelFieldDefinition(
484
+ name=field_name,
485
+ has_default=has_default,
486
+ default_value=default_value,
487
+ )
488
+
489
+ return fields
490
+
491
+
492
+ def _get_dataclass_fields(cls: type) -> Dict[str, ModelFieldDefinition]:
493
+ """Extract field definitions from a dataclass."""
494
+ import dataclasses
495
+
496
+ if not dataclasses.is_dataclass(cls):
497
+ return {}
498
+
499
+ fields: Dict[str, ModelFieldDefinition] = {}
500
+ for field in dataclasses.fields(cls):
501
+ has_default = False
502
+ default_value = None
503
+
504
+ if field.default is not dataclasses.MISSING:
505
+ has_default = True
506
+ default_value = field.default
507
+ elif field.default_factory is not dataclasses.MISSING:
508
+ # We can't serialize factory functions, so treat as no default
509
+ has_default = False
510
+
511
+ fields[field.name] = ModelFieldDefinition(
512
+ name=field.name,
513
+ has_default=has_default,
514
+ default_value=default_value,
515
+ )
516
+
517
+ return fields
518
+
519
+
520
+ def _discover_model_definitions(module: Any) -> Dict[str, ModelDefinition]:
521
+ """Discover all Pydantic models and dataclasses that can be used in workflows.
522
+
523
+ Only discovers "simple" models without custom validators or __post_init__.
524
+ """
525
+ models: Dict[str, ModelDefinition] = {}
526
+
527
+ for attr_name in dir(module):
528
+ try:
529
+ attr = getattr(module, attr_name)
530
+ except AttributeError:
531
+ continue
532
+
533
+ if not isinstance(attr, type):
534
+ continue
535
+
536
+ # Check if this class is defined in this module or imported
537
+ # We want to include both, as models might be imported
538
+ if _is_simple_pydantic_model(attr):
539
+ fields = _get_pydantic_model_fields(attr)
540
+ models[attr_name] = ModelDefinition(
541
+ class_name=attr_name,
542
+ fields=fields,
543
+ is_pydantic=True,
544
+ )
545
+ elif _is_simple_dataclass(attr):
546
+ fields = _get_dataclass_fields(attr)
547
+ models[attr_name] = ModelDefinition(
548
+ class_name=attr_name,
549
+ fields=fields,
550
+ is_pydantic=False,
551
+ )
552
+
553
+ return models
554
+
555
+
556
+ def _discover_module_imports(module: Any) -> Dict[str, ImportedName]:
557
+ """Discover imports in a module by parsing its source.
558
+
559
+ Tracks imports like:
560
+ - from asyncio import sleep -> {"sleep": ImportedName("sleep", "asyncio", "sleep")}
561
+ - from asyncio import sleep as s -> {"s": ImportedName("s", "asyncio", "sleep")}
562
+ """
563
+ imported: Dict[str, ImportedName] = {}
564
+
565
+ try:
566
+ source = inspect.getsource(module)
567
+ tree = ast.parse(source)
568
+ except (OSError, TypeError):
569
+ # Can't get source (e.g., built-in module)
570
+ return imported
571
+
572
+ for node in ast.walk(tree):
573
+ if isinstance(node, ast.ImportFrom) and node.module:
574
+ for alias in node.names:
575
+ local_name = alias.asname if alias.asname else alias.name
576
+ imported[local_name] = ImportedName(
577
+ local_name=local_name,
578
+ module=node.module,
579
+ original_name=alias.name,
580
+ )
581
+
582
+ return imported
583
+
584
+
585
+ class IRBuilder(ast.NodeVisitor):
586
+ """Builds IR from Python AST with deep transformations."""
587
+
588
+ def __init__(
589
+ self,
590
+ action_defs: Dict[str, ActionDefinition],
591
+ ctx: TransformContext,
592
+ imported_names: Optional[Dict[str, ImportedName]] = None,
593
+ module_functions: Optional[Set[str]] = None,
594
+ model_defs: Optional[Dict[str, ModelDefinition]] = None,
595
+ ):
596
+ self._action_defs = action_defs
597
+ self._ctx = ctx
598
+ self._imported_names = imported_names or {}
599
+ self._module_functions = module_functions or set()
600
+ self._model_defs = model_defs or {}
601
+ self.function_def: Optional[ir.FunctionDef] = None
602
+ self._statements: List[ir.Statement] = []
603
+
604
+ def visit_FunctionDef(self, node: ast.FunctionDef) -> Any:
605
+ """Visit a function definition (the workflow's run method)."""
606
+ # Extract inputs from function parameters (skip 'self')
607
+ inputs: List[str] = []
608
+ for arg in node.args.args[1:]: # Skip 'self'
609
+ inputs.append(arg.arg)
610
+
611
+ # Create the function definition
612
+ self.function_def = ir.FunctionDef(
613
+ name=node.name,
614
+ io=ir.IoDecl(inputs=inputs, outputs=[]),
615
+ span=_make_span(node),
616
+ )
617
+
618
+ # Visit the body - _visit_statement now returns a list
619
+ self._statements = []
620
+ for stmt in node.body:
621
+ ir_stmts = self._visit_statement(stmt)
622
+ self._statements.extend(ir_stmts)
623
+
624
+ # Set the body
625
+ self.function_def.body.CopyFrom(ir.Block(statements=self._statements))
626
+
627
+ def visit_AsyncFunctionDef(self, node: ast.AsyncFunctionDef) -> Any:
628
+ """Visit an async function definition (the workflow's run method)."""
629
+ # Handle async the same way as sync for IR building
630
+ inputs: List[str] = []
631
+ for arg in node.args.args[1:]: # Skip 'self'
632
+ inputs.append(arg.arg)
633
+
634
+ self.function_def = ir.FunctionDef(
635
+ name=node.name,
636
+ io=ir.IoDecl(inputs=inputs, outputs=[]),
637
+ span=_make_span(node),
638
+ )
639
+
640
+ self._statements = []
641
+ for stmt in node.body:
642
+ ir_stmts = self._visit_statement(stmt)
643
+ self._statements.extend(ir_stmts)
644
+
645
+ self.function_def.body.CopyFrom(ir.Block(statements=self._statements))
646
+
647
+ def _visit_statement(self, node: ast.stmt) -> List[ir.Statement]:
648
+ """Convert a Python statement to IR Statement(s).
649
+
650
+ Returns a list because some transformations (like try block hoisting)
651
+ may expand a single Python statement into multiple IR statements.
652
+
653
+ Raises UnsupportedPatternError for unsupported statement types.
654
+ """
655
+ if isinstance(node, ast.Assign):
656
+ dict_expanded = self._expand_dict_comprehension_assignment(node)
657
+ if dict_expanded is not None:
658
+ return dict_expanded
659
+ expanded = self._expand_list_comprehension_assignment(node)
660
+ if expanded is not None:
661
+ return expanded
662
+ result = self._visit_assign(node)
663
+ return [result] if result else []
664
+ elif isinstance(node, ast.Expr):
665
+ result = self._visit_expr_stmt(node)
666
+ return [result] if result else []
667
+ elif isinstance(node, ast.For):
668
+ return self._visit_for(node)
669
+ elif isinstance(node, ast.If):
670
+ result = self._visit_if(node)
671
+ return [result] if result else []
672
+ elif isinstance(node, ast.Try):
673
+ return self._visit_try(node)
674
+ elif isinstance(node, ast.Return):
675
+ return self._visit_return(node)
676
+ elif isinstance(node, ast.AugAssign):
677
+ result = self._visit_aug_assign(node)
678
+ return [result] if result else []
679
+ elif isinstance(node, ast.Pass):
680
+ # Pass statements are fine, they just don't produce IR
681
+ return []
682
+
683
+ # Check for unsupported statement types
684
+ self._check_unsupported_statement(node)
685
+
686
+ return []
687
+
688
+ def _check_unsupported_statement(self, node: ast.stmt) -> None:
689
+ """Check for unsupported statement types and raise descriptive errors."""
690
+ line = getattr(node, "lineno", None)
691
+ col = getattr(node, "col_offset", None)
692
+
693
+ if isinstance(node, ast.While):
694
+ raise UnsupportedPatternError(
695
+ "While loops are not supported",
696
+ RECOMMENDATIONS["while_loop"],
697
+ line=line,
698
+ col=col,
699
+ )
700
+ elif isinstance(node, (ast.With, ast.AsyncWith)):
701
+ raise UnsupportedPatternError(
702
+ "Context managers (with statements) are not supported",
703
+ RECOMMENDATIONS["with_statement"],
704
+ line=line,
705
+ col=col,
706
+ )
707
+ elif isinstance(node, ast.Raise):
708
+ raise UnsupportedPatternError(
709
+ "The 'raise' statement is not supported",
710
+ RECOMMENDATIONS["raise_statement"],
711
+ line=line,
712
+ col=col,
713
+ )
714
+ elif isinstance(node, ast.Assert):
715
+ raise UnsupportedPatternError(
716
+ "Assert statements are not supported",
717
+ RECOMMENDATIONS["assert_statement"],
718
+ line=line,
719
+ col=col,
720
+ )
721
+ elif isinstance(node, ast.Delete):
722
+ raise UnsupportedPatternError(
723
+ "The 'del' statement is not supported",
724
+ RECOMMENDATIONS["delete"],
725
+ line=line,
726
+ col=col,
727
+ )
728
+ elif isinstance(node, ast.Global):
729
+ raise UnsupportedPatternError(
730
+ "Global statements are not supported",
731
+ RECOMMENDATIONS["global_statement"],
732
+ line=line,
733
+ col=col,
734
+ )
735
+ elif isinstance(node, ast.Nonlocal):
736
+ raise UnsupportedPatternError(
737
+ "Nonlocal statements are not supported",
738
+ RECOMMENDATIONS["nonlocal_statement"],
739
+ line=line,
740
+ col=col,
741
+ )
742
+ elif isinstance(node, (ast.Import, ast.ImportFrom)):
743
+ raise UnsupportedPatternError(
744
+ "Import statements inside workflow run() are not supported",
745
+ RECOMMENDATIONS["import_statement"],
746
+ line=line,
747
+ col=col,
748
+ )
749
+ elif isinstance(node, ast.ClassDef):
750
+ raise UnsupportedPatternError(
751
+ "Class definitions inside workflow run() are not supported",
752
+ RECOMMENDATIONS["class_def"],
753
+ line=line,
754
+ col=col,
755
+ )
756
+ elif isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef)):
757
+ raise UnsupportedPatternError(
758
+ "Nested function definitions are not supported",
759
+ RECOMMENDATIONS["function_def"],
760
+ line=line,
761
+ col=col,
762
+ )
763
+ elif hasattr(ast, "Match") and isinstance(node, ast.Match):
764
+ raise UnsupportedPatternError(
765
+ "Match statements are not supported",
766
+ RECOMMENDATIONS["match"],
767
+ line=line,
768
+ col=col,
769
+ )
770
+
771
+ def _expand_list_comprehension_assignment(
772
+ self, node: ast.Assign
773
+ ) -> Optional[List[ir.Statement]]:
774
+ """Expand a list comprehension assignment into loop-based statements.
775
+
776
+ Example:
777
+ active_users = [user for user in users if user.active]
778
+
779
+ Becomes:
780
+ active_users = []
781
+ for user in users:
782
+ if user.active:
783
+ active_users = active_users + [user]
784
+ """
785
+ if not isinstance(node.value, ast.ListComp):
786
+ return None
787
+
788
+ if len(node.targets) != 1 or not isinstance(node.targets[0], ast.Name):
789
+ line = getattr(node, "lineno", None)
790
+ col = getattr(node, "col_offset", None)
791
+ raise UnsupportedPatternError(
792
+ "List comprehension assignments must target a single variable",
793
+ "Assign the comprehension to a simple variable like `results = [x for x in items]`",
794
+ line=line,
795
+ col=col,
796
+ )
797
+
798
+ listcomp = node.value
799
+ if len(listcomp.generators) != 1:
800
+ line = getattr(listcomp, "lineno", None)
801
+ col = getattr(listcomp, "col_offset", None)
802
+ raise UnsupportedPatternError(
803
+ "List comprehensions with multiple generators are not supported",
804
+ "Use nested for loops instead of combining multiple generators in one comprehension",
805
+ line=line,
806
+ col=col,
807
+ )
808
+
809
+ gen = listcomp.generators[0]
810
+ if gen.is_async:
811
+ line = getattr(listcomp, "lineno", None)
812
+ col = getattr(listcomp, "col_offset", None)
813
+ raise UnsupportedPatternError(
814
+ "Async list comprehensions are not supported",
815
+ "Rewrite using an explicit async for loop",
816
+ line=line,
817
+ col=col,
818
+ )
819
+
820
+ target_name = node.targets[0].id
821
+
822
+ # Initialize the accumulator list: active_users = []
823
+ init_assign_ast = ast.Assign(
824
+ targets=[ast.Name(id=target_name, ctx=ast.Store())],
825
+ value=ast.List(elts=[], ctx=ast.Load()),
826
+ type_comment=None,
827
+ )
828
+ ast.copy_location(init_assign_ast, node)
829
+ ast.fix_missing_locations(init_assign_ast)
830
+
831
+ def _make_append_assignment(value_expr: ast.expr) -> ast.Assign:
832
+ append_assign = ast.Assign(
833
+ targets=[ast.Name(id=target_name, ctx=ast.Store())],
834
+ value=ast.BinOp(
835
+ left=ast.Name(id=target_name, ctx=ast.Load()),
836
+ op=ast.Add(),
837
+ right=ast.List(elts=[copy.deepcopy(value_expr)], ctx=ast.Load()),
838
+ ),
839
+ type_comment=None,
840
+ )
841
+ ast.copy_location(append_assign, node.value)
842
+ ast.fix_missing_locations(append_assign)
843
+ return append_assign
844
+
845
+ append_statements: List[ast.stmt] = []
846
+ if isinstance(listcomp.elt, ast.IfExp):
847
+ then_assign = _make_append_assignment(listcomp.elt.body)
848
+ else_assign = _make_append_assignment(listcomp.elt.orelse)
849
+ branch_if = ast.If(
850
+ test=copy.deepcopy(listcomp.elt.test),
851
+ body=[then_assign],
852
+ orelse=[else_assign],
853
+ )
854
+ ast.copy_location(branch_if, listcomp.elt)
855
+ ast.fix_missing_locations(branch_if)
856
+ append_statements.append(branch_if)
857
+ else:
858
+ append_statements.append(_make_append_assignment(listcomp.elt))
859
+
860
+ loop_body: List[ast.stmt] = append_statements
861
+ if gen.ifs:
862
+ condition: ast.expr
863
+ if len(gen.ifs) == 1:
864
+ condition = copy.deepcopy(gen.ifs[0])
865
+ else:
866
+ condition = ast.BoolOp(op=ast.And(), values=[copy.deepcopy(iff) for iff in gen.ifs])
867
+ ast.copy_location(condition, gen.ifs[0])
868
+ if_stmt = ast.If(test=condition, body=append_statements, orelse=[])
869
+ ast.copy_location(if_stmt, gen.ifs[0])
870
+ ast.fix_missing_locations(if_stmt)
871
+ loop_body = [if_stmt]
872
+
873
+ loop_ast = ast.For(
874
+ target=copy.deepcopy(gen.target),
875
+ iter=copy.deepcopy(gen.iter),
876
+ body=loop_body,
877
+ orelse=[],
878
+ type_comment=None,
879
+ )
880
+ ast.copy_location(loop_ast, node)
881
+ ast.fix_missing_locations(loop_ast)
882
+
883
+ statements: List[ir.Statement] = []
884
+ init_stmt = self._visit_assign(init_assign_ast)
885
+ if init_stmt:
886
+ statements.append(init_stmt)
887
+ statements.extend(self._visit_for(loop_ast))
888
+
889
+ return statements
890
+
891
def _expand_dict_comprehension_assignment(
    self, node: ast.Assign
) -> Optional[List[ir.Statement]]:
    """Lower ``name = {key: value for target in iterable if cond}`` to a loop.

    The comprehension is rewritten into its imperative equivalent::

        name = {}
        for target in iterable:
            if cond_1 and cond_2 ...:   # only when the generator has filters
                name[key] = value

    Both synthetic statements are then lowered through ``_visit_assign`` and
    ``_visit_for``, so they get exactly the treatment hand-written code gets.

    Args:
        node: The assignment whose value may be a dict comprehension.

    Returns:
        ``None`` if ``node.value`` is not an ``ast.DictComp``, so the caller
        can try other lowerings. Otherwise, the lowered IR statements with
        the initialization (when it lowers to something) first.

    Raises:
        UnsupportedPatternError: if the assignment does not have exactly one
            plain-name target, the comprehension has more than one generator,
            or its generator is ``async``.
    """
    comp = node.value
    if not isinstance(comp, ast.DictComp):
        return None

    def unsupported(anchor: ast.AST, message: str, hint: str) -> UnsupportedPatternError:
        # Build (not raise) the error so call sites read `raise unsupported(...)`.
        return UnsupportedPatternError(
            message,
            hint,
            line=getattr(anchor, "lineno", None),
            col=getattr(anchor, "col_offset", None),
        )

    has_single_name_target = len(node.targets) == 1 and isinstance(
        node.targets[0], ast.Name
    )
    if not has_single_name_target:
        raise unsupported(
            node,
            "Dict comprehension assignments must target a single variable",
            "Assign the comprehension to a simple variable like `result = {k: v for k, v in pairs}`",
        )
    if len(comp.generators) != 1:
        raise unsupported(
            comp,
            "Dict comprehensions with multiple generators are not supported",
            "Use nested for loops instead of combining multiple generators in one comprehension",
        )
    generator = comp.generators[0]
    if generator.is_async:
        raise unsupported(
            comp,
            "Async dict comprehensions are not supported",
            "Rewrite using an explicit async for loop",
        )

    accumulator = node.targets[0].id

    # name = {}
    init_ast = ast.Assign(
        targets=[ast.Name(id=accumulator, ctx=ast.Store())],
        value=ast.Dict(keys=[], values=[]),
        type_comment=None,
    )
    ast.fix_missing_locations(ast.copy_location(init_ast, node))

    # name[key] = value -- deep copies keep the user's tree untouched.
    store_ast = ast.Assign(
        targets=[
            ast.Subscript(
                value=ast.Name(id=accumulator, ctx=ast.Load()),
                slice=copy.deepcopy(comp.key),
                ctx=ast.Store(),
            )
        ],
        value=copy.deepcopy(comp.value),
        type_comment=None,
    )
    ast.fix_missing_locations(ast.copy_location(store_ast, comp))

    loop_body: List[ast.stmt] = [store_ast]
    if generator.ifs:
        first_filter = generator.ifs[0]
        test: ast.expr
        if len(generator.ifs) == 1:
            test = copy.deepcopy(first_filter)
        else:
            # Several `if` clauses on one generator are conjunctive.
            test = ast.BoolOp(
                op=ast.And(),
                values=[copy.deepcopy(clause) for clause in generator.ifs],
            )
        ast.copy_location(test, first_filter)
        guard = ast.If(test=test, body=[store_ast], orelse=[])
        ast.fix_missing_locations(ast.copy_location(guard, first_filter))
        loop_body = [guard]

    loop_ast = ast.For(
        target=copy.deepcopy(generator.target),
        iter=copy.deepcopy(generator.iter),
        body=loop_body,
        orelse=[],
        type_comment=None,
    )
    ast.fix_missing_locations(ast.copy_location(loop_ast, node))

    lowered: List[ir.Statement] = []
    init_stmt = self._visit_assign(init_ast)
    if init_stmt:
        lowered.append(init_stmt)
    lowered.extend(self._visit_for(loop_ast))
    return lowered
987
+
988
def _visit_assign(self, node: ast.Assign) -> Optional[ir.Statement]:
    """Lower a Python assignment into an IR ``Assignment`` statement.

    Every supported right-hand side is wrapped in ``ir.Assignment``, so tuple
    unpacking works the same way regardless of what produced the value:

    - Model constructors (as recognised by ``_is_model_constructor``) are
      converted to dict expressions.
    - ``await asyncio.gather(...)`` becomes a ``parallel_expr`` or
      ``spread_expr``.
    - Action calls (``a, b = await get_pair()``) become ``action_call``.
    - Anything else goes through ``_expr_to_ir``.

    Returns:
        The assignment statement, or ``None`` when ``_expr_to_ir`` cannot
        convert a plain right-hand side.

    Raises:
        UnsupportedPatternError: propagated from the helper checks, e.g.
            ``_check_constructor_in_assignment`` for non-model constructor
            calls such as ``x = MyClass(...)``.
    """
    stmt = ir.Statement(span=_make_span(node))
    targets = self._get_assign_targets(node.targets)
    rhs = node.value

    def finish(value: ir.Expr) -> ir.Statement:
        stmt.assignment.CopyFrom(ir.Assignment(targets=targets, value=value))
        return stmt

    # Models are allowed and lowered to dict literals. This must run before
    # the generic constructor rejection below.
    model_name = self._is_model_constructor(rhs)
    if model_name and isinstance(rhs, ast.Call):
        return finish(self._convert_model_constructor_to_dict(rhs, model_name))

    self._check_constructor_in_assignment(rhs)

    awaited = rhs.value if isinstance(rhs, ast.Await) else None
    if isinstance(awaited, ast.Call) and self._is_asyncio_gather_call(awaited):
        gathered = self._convert_asyncio_gather(awaited)
        if isinstance(gathered, ir.ParallelExpr):
            return finish(ir.Expr(parallel_expr=gathered, span=_make_span(node)))
        if gathered is not None:
            # Any other non-None gather result is a SpreadExpr.
            return finish(ir.Expr(spread_expr=gathered, span=_make_span(node)))

    action_call = self._extract_action_call(rhs)
    if action_call:
        return finish(ir.Expr(action_call=action_call, span=_make_span(node)))

    converted = _expr_to_ir(rhs)
    return finish(converted) if converted else None
1048
+
1049
def _visit_expr_stmt(self, node: ast.Expr) -> Optional[ir.Statement]:
    """Lower an expression statement evaluated only for its side effects.

    Handled forms, in priority order:

    - ``await asyncio.gather(...)``: a parallel gather becomes a
      ``parallel_block`` statement. A spread gather becomes an assignment with
      no targets.
    - An action call becomes an ``action_call`` statement.
    - ``name.append(x)`` is rewritten to ``name = name + [x]``. This makes the
      mutation an explicit data dependency.
    - Anything else that ``_expr_to_ir`` can convert becomes an ``expr_stmt``.

    Returns:
        The lowered statement, or ``None`` if nothing could be converted.
    """
    stmt = ir.Statement(span=_make_span(node))
    value = node.value

    awaited = value.value if isinstance(value, ast.Await) else None
    if isinstance(awaited, ast.Call) and self._is_asyncio_gather_call(awaited):
        gathered = self._convert_asyncio_gather(awaited)
        if isinstance(gathered, ir.ParallelExpr):
            block = ir.ParallelBlock()
            block.calls.extend(gathered.calls)
            stmt.parallel_block.CopyFrom(block)
            return stmt
        if gathered is not None:
            # e.g. `await asyncio.gather(*[action(x) for x in items])`
            spread = ir.Expr(spread_expr=gathered, span=_make_span(node))
            stmt.assignment.CopyFrom(ir.Assignment(targets=[], value=spread))
            return stmt

    action_call = self._extract_action_call(value)
    if action_call:
        stmt.action_call.CopyFrom(action_call)
        return stmt

    is_name_append = (
        isinstance(value, ast.Call)
        and isinstance(value.func, ast.Attribute)
        and value.func.attr == "append"
        and isinstance(value.func.value, ast.Name)
        and len(value.args) == 1
    )
    if is_name_append:
        list_name = value.func.value.id
        item_ir = _expr_to_ir(value.args[0])
        if item_ir:
            # name = name + [item]
            concatenated = ir.Expr(
                binary_op=ir.BinaryOp(
                    op=ir.BinaryOperator.BINARY_OP_ADD,
                    left=ir.Expr(
                        variable=ir.Variable(name=list_name), span=_make_span(node)
                    ),
                    right=ir.Expr(
                        list=ir.ListExpr(elements=[item_ir]), span=_make_span(node)
                    ),
                ),
                span=_make_span(node),
            )
            stmt.assignment.CopyFrom(
                ir.Assignment(targets=[list_name], value=concatenated)
            )
            return stmt

    plain = _expr_to_ir(value)
    if not plain:
        return None
    stmt.expr_stmt.CopyFrom(ir.ExprStmt(expr=plain))
    return stmt
1117
+
1118
def _visit_for(self, node: ast.For) -> List[ir.Statement]:
    """Lower a ``for`` loop, always isolating its body in a synthetic function.

    The body is lowered statement by statement, then handed to
    ``_wrap_body_as_function`` (label ``"for_body"``) with the loop variables
    as inputs. The resulting ``ForLoop`` body therefore holds a single call.

    Some names are mutated by the body but defined outside it. These are
    found by ``_detect_accumulator_targets`` and passed to the wrapper as
    ``modified_vars``. Examples:

    - ``results.append(value)``
    - ``result[key] = value``
    - ``results = results + [value]``
    - ``count = count + 1``

    Python::

        for item in items:
            a = await step_one(item)
            b = await step_two(a)

    becomes the IR equivalent of::

        fn __for_body_1__(item):
            a = @step_one(item=item)
            b = @step_two(a=a)
            return b

        for item in items:
            __for_body_1__(item=item)

    Returns:
        A one-element list with the ``for_loop`` statement, or ``[]`` when
        ``_expr_to_ir`` cannot convert the iterable.

    Raises:
        UnsupportedPatternError: if the loop target is not a single name or a
            flat tuple/list of names (e.g. ``for i, (a, b) in ...`` or
            ``for head, *rest in ...``).
    """
    target = node.target
    loop_vars: List[str]
    if isinstance(target, ast.Name):
        loop_vars = [target.id]
    elif isinstance(target, (ast.Tuple, ast.List)) and all(
        isinstance(elt, ast.Name) for elt in target.elts
    ):
        loop_vars = [elt.id for elt in target.elts]
    else:
        # `loop_vars` is a flat list of names. Nested, starred, or attribute
        # targets cannot be represented, and silently dropping the
        # unrepresentable names would leave them unbound inside the body.
        raise UnsupportedPatternError(
            "For loop targets must be a variable name or a flat tuple of names",
            "Unpack nested values inside the loop body, e.g. "
            "`for i, pair in enumerate(pairs):` followed by `a, b = pair`",
            line=getattr(target, "lineno", None),
            col=getattr(target, "col_offset", None),
        )

    iterable = _expr_to_ir(node.iter)
    if not iterable:
        return []

    # Names defined by this iteration: loop variables plus any name assigned
    # by a top-level statement of the lowered body.
    in_scope_vars = set(loop_vars)
    body_stmts: List[ir.Statement] = []
    for child in node.body:
        lowered = self._visit_statement(child)
        body_stmts.extend(lowered)
        for lowered_stmt in lowered:
            if lowered_stmt.HasField("assignment"):
                in_scope_vars.update(lowered_stmt.assignment.targets)

    modified_vars = self._detect_accumulator_targets(body_stmts, in_scope_vars)

    # Always wrap, for variable isolation: values flow in and out explicitly
    # through the synthetic function's inputs and modified_vars.
    body_stmts = self._wrap_body_as_function(
        body_stmts, "for_body", node, inputs=loop_vars, modified_vars=modified_vars
    )

    stmt = ir.Statement(span=_make_span(node))
    single_call_body = self._stmts_to_single_call_body(body_stmts, _make_span(node))
    stmt.for_loop.CopyFrom(
        ir.ForLoop(loop_vars=loop_vars, iterable=iterable, body=single_call_body)
    )
    return [stmt]
1196
+
1197
def _detect_accumulator_targets(
    self, stmts: List[ir.Statement], in_scope_vars: set
) -> List[str]:
    """Return outer names that ``stmts`` mutate, in first-seen order.

    Each statement is classified by ``_extract_accumulator_from_stmt``, which
    recognises three patterns:

    - mutating method calls such as ``results.append(v)``;
    - subscript stores such as ``result[key] = v``;
    - self-referential assignments such as ``results = results + [v]`` or
      ``count = count + 1``.

    For conditional statements, the already-lowered branch bodies are
    inspected too. Every branch ``targets`` entry not in ``in_scope_vars`` is
    reported, and non-empty branch ``statements`` are scanned recursively.

    Args:
        stmts: Lowered IR statements of a loop or branch body.
        in_scope_vars: Names considered local to that body.

    Returns:
        Accumulator names, without duplicates, in first-seen order.
    """
    # A dict doubles as an insertion-ordered set; setdefault keeps the first
    # occurrence's position.
    found: Dict[str, None] = {}

    for stmt in stmts:
        direct = self._extract_accumulator_from_stmt(stmt, in_scope_vars)
        if direct:
            found.setdefault(direct, None)

        if not stmt.HasField("conditional"):
            continue

        cond = stmt.conditional
        bodies = []
        if cond.HasField("if_branch"):
            bodies.append(cond.if_branch.body)
        bodies.extend(
            branch.body for branch in cond.elif_branches if branch.HasField("body")
        )
        if cond.HasField("else_branch"):
            bodies.append(cond.else_branch.body)

        for body in bodies:
            for target in body.targets:
                if target not in in_scope_vars:
                    found.setdefault(target, None)
            if body.statements:
                nested = self._detect_accumulator_targets(
                    list(body.statements), in_scope_vars
                )
                for name in nested:
                    found.setdefault(name, None)

    return list(found)
1248
+
1249
def _extract_accumulator_from_stmt(
    self, stmt: ir.Statement, in_scope_vars: set
) -> Optional[str]:
    """Return the outer variable a single statement mutates, if any.

    Checked in order:

    1. An expression statement calling a mutating method (``append``,
       ``extend``, ``add``, ``update``, ``insert``, ``pop``, ``remove``,
       ``clear``) on a receiver that is not in ``in_scope_vars``. The receiver
       is returned; it may itself be dotted, e.g. ``obj.items``.
    2. An assignment to a subscript target such as ``d[k]`` whose base name
       is not in ``in_scope_vars``. The base name is returned.
    3. A self-referential assignment whose target also appears on the
       right-hand side, e.g. ``xs = xs + [y]``. The target is returned even
       when it is in scope, because its previous value comes from outside.

    Returns:
        The variable name, or ``None`` when no pattern matches.
    """
    if stmt.HasField("expr_stmt"):
        expr = stmt.expr_stmt.expr
        if expr.HasField("function_call"):
            receiver, dot, method = expr.function_call.name.rpartition(".")
            is_mutator = method in (
                "append",
                "extend",
                "add",
                "update",
                "insert",
                "pop",
                "remove",
                "clear",
            )
            if dot and is_mutator and receiver and receiver not in in_scope_vars:
                return receiver

    if not stmt.HasField("assignment"):
        return None
    assignment = stmt.assignment

    # Subscript stores: `base[...] = value`.
    for target in assignment.targets:
        if "[" in target:
            base = target.partition("[")[0]
            if base and base not in in_scope_vars:
                return base

    # Self-referential updates: `x = x + [y]`, `n = n + 1`.
    rhs_vars = self._collect_variables_from_expr(assignment.value)
    return next((target for target in assignment.targets if target in rhs_vars), None)
1303
+
1304
def _collect_variables_from_expr(self, expr: ir.Expr) -> set:
    """Return the set of variable names referenced anywhere inside ``expr``.

    The walk descends into:

    - variables;
    - binary and unary operators;
    - list and dict literals;
    - ``index`` and ``dot`` access;
    - keyword arguments of function calls and action calls;
    - keyword arguments of every call inside a ``parallel_expr``. This is the
      form ``x = await asyncio.gather(...)`` is lowered to.

    Other expression kinds contribute no names.

    The result is an unordered ``set``. Callers that need a stable order must
    sort or otherwise order it themselves.
    """
    found: set = set()
    children: List[ir.Expr] = []

    if expr.HasField("variable"):
        found.add(expr.variable.name)
    elif expr.HasField("binary_op"):
        children += [expr.binary_op.left, expr.binary_op.right]
    elif expr.HasField("unary_op"):
        children.append(expr.unary_op.operand)
    elif expr.HasField("list"):
        children.extend(expr.list.elements)
    elif expr.HasField("dict"):
        children.extend(expr.dict.keys)
        children.extend(expr.dict.values)
    elif expr.HasField("index"):
        children += [expr.index.value, expr.index.index]
    elif expr.HasField("dot"):
        children.append(expr.dot.object)
    elif expr.HasField("function_call"):
        children.extend(kwarg.value for kwarg in expr.function_call.kwargs)
    elif expr.HasField("action_call"):
        children.extend(kwarg.value for kwarg in expr.action_call.kwargs)
    elif expr.HasField("parallel_expr"):
        # Previously unhandled. Variables passed to gathered calls were
        # missed, so wrapped try/handler bodies did not receive them as inputs.
        for call in expr.parallel_expr.calls:
            if call.HasField("action"):
                children.extend(kwarg.value for kwarg in call.action.kwargs)
            elif call.HasField("function"):
                children.extend(kwarg.value for kwarg in call.function.kwargs)

    for child in children:
        found |= self._collect_variables_from_expr(child)
    return found
1336
+
1337
def _visit_if(self, node: ast.If) -> Optional[ir.Statement]:
    """Lower an ``if``/``elif``/``else`` chain into an IR ``Conditional``.

    Every branch body is lowered and then isolated in a synthetic function by
    ``_lower_wrapped_if_branch``, using labels ``"if_then"``, ``"if_elif"``
    and ``"if_else"``. Each branch therefore holds a single call. Outer names
    a branch mutates (e.g. ``results.append(x)``, see
    ``_detect_accumulator_targets``) are passed through as ``modified_vars``.

    Python::

        if condition:
            a = await action_a()
            b = await action_b(a)
        else:
            c = await action_c()

    becomes the IR equivalent of::

        fn __if_then_1__():
            a = @action_a()
            b = @action_b(a=a)
            return b

        if condition:
            __if_then_1__()
        else:
            <wrapped else body>

    Returns:
        The ``conditional`` statement, or ``None`` when ``_expr_to_ir``
        cannot convert the ``if`` test. An ``elif`` whose test cannot be
        converted is skipped, and the chain continues with its ``orelse``.
    """
    condition = _expr_to_ir(node.test)
    if not condition:
        return None

    stmt = ir.Statement(span=_make_span(node))
    then_body = self._lower_wrapped_if_branch(node.body, "if_then", node)
    conditional = ir.Conditional(
        if_branch=ir.IfBranch(
            condition=condition,
            body=self._stmts_to_single_call_body(then_body, _make_span(node)),
            span=_make_span(node),
        )
    )

    # Walk the `orelse` chain. A lone nested `If` is an `elif`; anything
    # else is the final `else` block.
    current = node
    while current.orelse:
        rest = current.orelse
        if len(rest) == 1 and isinstance(rest[0], ast.If):
            elif_node = rest[0]
            elif_condition = _expr_to_ir(elif_node.test)
            if elif_condition:
                elif_body = self._lower_wrapped_if_branch(
                    elif_node.body, "if_elif", elif_node
                )
                conditional.elif_branches.append(
                    ir.ElifBranch(
                        condition=elif_condition,
                        body=self._stmts_to_single_call_body(
                            elif_body, _make_span(elif_node)
                        ),
                        span=_make_span(elif_node),
                    )
                )
            # Advance unconditionally so the walk always terminates.
            current = elif_node
            continue

        # `rest` is non-empty here (loop condition), so `rest[0]` is safe.
        else_anchor = rest[0]
        else_body = self._lower_wrapped_if_branch(rest, "if_else", else_anchor)
        conditional.else_branch.CopyFrom(
            ir.ElseBranch(
                body=self._stmts_to_single_call_body(else_body, _make_span(else_anchor)),
                span=_make_span(else_anchor),
            )
        )
        break

    stmt.conditional.CopyFrom(conditional)
    return stmt


def _lower_wrapped_if_branch(
    self, nodes: List[ast.stmt], label: str, anchor: ast.AST
) -> List[ir.Statement]:
    """Lower one if/elif/else body and isolate it in a synthetic function.

    Names assigned by top-level statements of the lowered body are treated as
    branch-local. ``_detect_accumulator_targets`` finds outer names the body
    mutates, and these are handed to ``_wrap_body_as_function`` as
    ``modified_vars`` together with ``label`` and ``anchor``.
    """
    lowered: List[ir.Statement] = []
    for child in nodes:
        lowered.extend(self._visit_statement(child))
    branch_locals = self._collect_assigned_vars(lowered)
    modified = self._detect_accumulator_targets(lowered, branch_locals)
    return self._wrap_body_as_function(lowered, label, anchor, modified_vars=modified)
1443
+
1444
def _collect_assigned_vars(self, stmts: List[ir.Statement]) -> set:
    """Return the names assigned by top-level ``assignment`` statements.

    Only ``stmts`` themselves are inspected. Assignments nested inside
    conditionals, loops, or try bodies are not visited; see
    ``_collect_assigned_vars_in_order`` for the recursive, ordered variant.
    """
    return {
        target
        for stmt in stmts
        if stmt.HasField("assignment")
        for target in stmt.assignment.targets
    }
1451
+
1452
def _collect_assigned_vars_in_order(self, stmts: List[ir.Statement]) -> list[str]:
    """Collect assigned variable names in statement order, deduplicated.

    Names come from top-level ``assignment`` targets. The walk also recurses
    into the lowered bodies of:

    - conditional branches (if, each elif, else);
    - ``for_loop`` bodies;
    - ``try_except`` try bodies and handler bodies.

    Each name is reported once, at its first occurrence.
    """
    # A dict doubles as an insertion-ordered set.
    ordered: Dict[str, None] = {}

    def absorb(names) -> None:
        for name in names:
            ordered.setdefault(name, None)

    def absorb_body(body) -> None:
        absorb(self._collect_assigned_vars_in_order(list(body.statements)))

    for stmt in stmts:
        if stmt.HasField("assignment"):
            absorb(stmt.assignment.targets)

        if stmt.HasField("conditional"):
            cond = stmt.conditional
            if cond.HasField("if_branch") and cond.if_branch.HasField("body"):
                absorb_body(cond.if_branch.body)
            for branch in cond.elif_branches:
                if branch.HasField("body"):
                    absorb_body(branch.body)
            if cond.HasField("else_branch") and cond.else_branch.HasField("body"):
                absorb_body(cond.else_branch.body)

        if stmt.HasField("for_loop") and stmt.for_loop.HasField("body"):
            absorb_body(stmt.for_loop.body)

        if stmt.HasField("try_except"):
            te = stmt.try_except
            # Gate on the try body's presence, as every other branch here
            # does. The previous check on its `span` could skip a try body
            # that has statements.
            if te.HasField("try_body"):
                absorb_body(te.try_body)
            for handler in te.handlers:
                if handler.HasField("body"):
                    absorb_body(handler.body)

    return list(ordered)
1514
+
1515
def _collect_variables_from_single_call_body(self, body: ir.SingleCallBody) -> list[str]:
    """Collect variable names referenced by a lowered single-call body.

    Names come in this order:

    1. Names used in the keyword arguments of ``body.call`` (an action or a
       function call).
    2. Names referenced by ``body.statements``, as reported by
       ``_collect_variables_from_statements``.

    Each name appears once, at its first occurrence.

    ``_collect_variables_from_expr`` returns a ``set``, and the iteration
    order of a set of strings varies with per-process hash randomization.
    Names from a single keyword argument are therefore sorted. This keeps the
    result stable across interpreter runs; it feeds e.g. the free-variable
    inputs ``_visit_try`` computes.
    """
    # A dict doubles as an insertion-ordered set.
    ordered: Dict[str, None] = {}

    if body.HasField("call"):
        call = body.call
        if call.HasField("action"):
            kwargs = call.action.kwargs
        elif call.HasField("function"):
            kwargs = call.function.kwargs
        else:
            kwargs = []
        for kwarg in kwargs:
            for name in sorted(self._collect_variables_from_expr(kwarg.value)):
                ordered.setdefault(name, None)

    for stmt in body.statements:
        for name in self._collect_variables_from_statements([stmt]):
            ordered.setdefault(name, None)

    return list(ordered)
1541
+
1542
def _collect_variables_from_statements(self, stmts: List[ir.Statement]) -> list[str]:
    """Collect variable references from lowered statements, in statement order.

    Per statement, the following are inspected:

    - assignment and return values;
    - action-call and parallel-block keyword arguments;
    - expression statements;
    - conditional tests and branch bodies;
    - for-loop iterables and bodies;
    - try/handler bodies;
    - spread-action collections and keyword arguments.

    Nested single-call bodies go through
    ``_collect_variables_from_single_call_body``. Each name appears once, at
    its first occurrence.

    ``_collect_variables_from_expr`` returns an unordered ``set``. Names from
    one expression are sorted so the overall result does not depend on
    per-process string hashing.
    """
    # A dict doubles as an insertion-ordered set.
    ordered: Dict[str, None] = {}

    def add_expr(expr: ir.Expr) -> None:
        for name in sorted(self._collect_variables_from_expr(expr)):
            ordered.setdefault(name, None)

    def add_kwargs(kwargs) -> None:
        for kwarg in kwargs:
            add_expr(kwarg.value)

    def add_body(body: ir.SingleCallBody) -> None:
        for name in self._collect_variables_from_single_call_body(body):
            ordered.setdefault(name, None)

    for stmt in stmts:
        if stmt.HasField("assignment") and stmt.assignment.HasField("value"):
            add_expr(stmt.assignment.value)

        if stmt.HasField("return_stmt") and stmt.return_stmt.HasField("value"):
            add_expr(stmt.return_stmt.value)

        if stmt.HasField("action_call"):
            # Same names `_collect_variables_from_expr` yields for an action
            # call, without building a temporary `ir.Expr` wrapper.
            add_kwargs(stmt.action_call.kwargs)

        if stmt.HasField("expr_stmt"):
            add_expr(stmt.expr_stmt.expr)

        if stmt.HasField("conditional"):
            cond = stmt.conditional
            if cond.HasField("if_branch"):
                if cond.if_branch.HasField("condition"):
                    add_expr(cond.if_branch.condition)
                if cond.if_branch.HasField("body"):
                    add_body(cond.if_branch.body)
            for branch in cond.elif_branches:
                if branch.HasField("condition"):
                    add_expr(branch.condition)
                if branch.HasField("body"):
                    add_body(branch.body)
            if cond.HasField("else_branch") and cond.else_branch.HasField("body"):
                add_body(cond.else_branch.body)

        if stmt.HasField("for_loop"):
            loop = stmt.for_loop
            if loop.HasField("iterable"):
                add_expr(loop.iterable)
            if loop.HasField("body"):
                add_body(loop.body)

        if stmt.HasField("try_except"):
            te = stmt.try_except
            if te.HasField("try_body"):
                add_body(te.try_body)
            for handler in te.handlers:
                if handler.HasField("body"):
                    add_body(handler.body)

        if stmt.HasField("parallel_block"):
            for call in stmt.parallel_block.calls:
                if call.HasField("action"):
                    add_kwargs(call.action.kwargs)
                elif call.HasField("function"):
                    add_kwargs(call.function.kwargs)

        if stmt.HasField("spread_action"):
            spread = stmt.spread_action
            if spread.HasField("collection"):
                add_expr(spread.collection)
            if spread.HasField("action"):
                add_kwargs(spread.action.kwargs)

    return list(ordered)
1662
+
1663
def _visit_try(self, node: ast.Try) -> List[ir.Statement]:
    """Lower ``try``/``except``, isolating every body in a synthetic function.

    The try body and each handler body are lowered and wrapped via
    ``_wrap_body_as_function``, with labels ``"try_body"`` and
    ``"except_handler"``. For each wrapped body:

    - inputs are its free variables (referenced but not assigned inside it)
      plus accumulator targets from ``_detect_accumulator_targets``;
    - outputs (``modified_vars``) are every variable it assigns, in order,
      plus those accumulator targets.

    Python::

        try:
            a = await setup_action()
            b = await risky_action(a)
            return f"success:{b}"
        except SomeError:
            ...

    becomes the IR equivalent of::

        fn __try_body_1__():
            a = @setup_action()
            b = @risky_action(a=a)
            return f"success:{b}"

        try:
            __try_body_1__()
        except SomeError:
            ...

    Exception types are recorded by simple class name:

    - ``except ValueError`` gives ``["ValueError"]``;
    - ``except mod.Err`` gives ``["Err"]``;
    - a tuple gives one entry per element;
    - a bare ``except:`` gives ``[]``.

    Returns:
        A one-element list containing the ``try_except`` statement.

    Raises:
        UnsupportedPatternError: if an ``except`` clause names an exception
            with anything other than a name, a dotted attribute, or a tuple
            of those. Dropping such an entry could leave ``[]``, which is how
            a bare ``except:`` is represented, silently widening the handler.
    """
    # NOTE(review): node.orelse and node.finalbody are not lowered here;
    # confirm they are rejected by validation elsewhere.

    def lower(nodes: List[ast.stmt]) -> List[ir.Statement]:
        """Lower Python statements, recursing into nested structures."""
        lowered: List[ir.Statement] = []
        for child in nodes:
            lowered.extend(self._visit_statement(child))
        return lowered

    def isolate(
        stmts: List[ir.Statement], label: str, anchor: ast.AST
    ) -> List[ir.Statement]:
        """Wrap ``stmts`` in a synthetic function with explicit inputs/outputs."""
        assigned = self._collect_assigned_vars_in_order(stmts)
        assigned_set = set(assigned)
        free = [
            var
            for var in self._collect_variables_from_statements(stmts)
            if var not in assigned_set
        ]
        accumulators = self._detect_accumulator_targets(stmts, assigned_set)
        # dict.fromkeys de-duplicates while keeping first-seen order.
        inputs = list(dict.fromkeys(free + accumulators))
        outputs = list(dict.fromkeys(assigned + accumulators))
        return self._wrap_body_as_function(
            stmts, label, anchor, inputs=inputs, modified_vars=outputs
        )

    def exception_names(type_expr: Optional[ast.expr]) -> List[str]:
        """Return the simple class names an ``except`` clause catches."""
        if type_expr is None:
            return []  # bare `except:`
        elements = type_expr.elts if isinstance(type_expr, ast.Tuple) else [type_expr]
        names: List[str] = []
        for element in elements:
            if isinstance(element, ast.Name):
                names.append(element.id)
            elif isinstance(element, ast.Attribute):
                names.append(element.attr)
            else:
                raise UnsupportedPatternError(
                    "Unsupported exception type in except clause",
                    "Name the exception class directly, e.g. `except ValueError:` "
                    "or `except (KeyError, errors.CustomError):`",
                    line=getattr(element, "lineno", None),
                    col=getattr(element, "col_offset", None),
                )
        return names

    try_stmts = isolate(lower(node.body), "try_body", node)

    handlers: List[ir.ExceptHandler] = []
    for handler in node.handlers:
        exception_types = exception_names(handler.type)
        # Anchor the synthetic function on the handler itself, matching the
        # handler span below, rather than on the enclosing `try`.
        handler_stmts = isolate(lower(handler.body), "except_handler", handler)
        handlers.append(
            ir.ExceptHandler(
                exception_types=exception_types,
                body=self._stmts_to_single_call_body(handler_stmts, _make_span(handler)),
                span=_make_span(handler),
            )
        )

    try_stmt = ir.Statement(span=_make_span(node))
    try_stmt.try_except.CopyFrom(
        ir.TryExcept(
            try_body=self._stmts_to_single_call_body(try_stmts, _make_span(node)),
            handlers=handlers,
        )
    )
    return [try_stmt]
1786
+
1787
def _count_calls(self, stmts: List[ir.Statement]) -> int:
    """Return how many call statements appear at the top level of ``stmts``.

    A statement counts once when it is:

    * a bare action call statement,
    * an assignment whose value is an action call or a function call
      (synthetic functions included), or
    * an expression statement wrapping a function call.

    Each control flow body is limited to one such call.
    """

    def _is_call(stmt: ir.Statement) -> bool:
        if stmt.HasField("action_call"):
            return True
        if stmt.HasField("assignment"):
            value = stmt.assignment.value
            return value.HasField("action_call") or value.HasField("function_call")
        if stmt.HasField("expr_stmt"):
            return stmt.expr_stmt.expr.HasField("function_call")
        return False

    return sum(1 for stmt in stmts if _is_call(stmt))
def _stmts_to_single_call_body(
    self, stmts: List[ir.Statement], span: ir.Span
) -> ir.SingleCallBody:
    """Pack ``stmts`` into an ``ir.SingleCallBody``.

    The first statement that is a call decides the result:

    * a bare action call statement -> ``call.action`` with no targets;
    * an assignment from an action or function call -> ``call`` plus every
      assignment target, so tuple unpacking is preserved;
    * an expression statement wrapping a function call -> ``call.function``.

    If none of the statements is a call, all of them are stored unchanged as
    pure data statements.

    NOTE(review): once a call is found, every other statement in ``stmts`` is
    left out of the result. Callers appear to wrap bodies into a single call
    first (see ``_wrap_body_as_function``) - confirm this holds at every call
    site.
    """
    body = ir.SingleCallBody(span=span)

    for stmt in stmts:
        action = None
        function = None
        targets = ()
        if stmt.HasField("action_call"):
            # A bare action statement has no target (side effect only).
            action = stmt.action_call
        elif stmt.HasField("assignment"):
            value = stmt.assignment.value
            # Keep every target so tuple unpacking survives.
            targets = stmt.assignment.targets
            if value.HasField("action_call"):
                action = value.action_call
            elif value.HasField("function_call"):
                function = value.function_call
        elif stmt.HasField("expr_stmt") and stmt.expr_stmt.expr.HasField("function_call"):
            function = stmt.expr_stmt.expr.function_call

        if action is None and function is None:
            continue

        call = ir.Call()
        if action is not None:
            call.action.CopyFrom(action)
        else:
            call.function.CopyFrom(function)
        body.targets.extend(targets)
        body.call.CopyFrom(call)
        return body

    # No call anywhere: keep everything as pure data statements.
    body.statements.extend(stmts)
    return body
def _wrap_body_as_function(
    self,
    body: List[ir.Statement],
    prefix: str,
    node: ast.AST,
    inputs: Optional[List[str]] = None,
    modified_vars: Optional[List[str]] = None,
) -> List[ir.Statement]:
    """Hoist ``body`` into a synthetic function and return a call to it.

    A new ``ir.FunctionDef``, named by ``self._ctx.next_implicit_fn_name(prefix)``,
    is appended to ``self._ctx.implicit_functions``. Its inputs are ``inputs``
    followed by each entry of ``modified_vars`` not already present; its
    outputs are ``modified_vars``.

    Args:
        body: Statements that become the synthetic function's body.
        prefix: Name prefix for the synthetic function.
        node: AST node used for the spans of everything created here.
        inputs: Variables passed in as keyword arguments (e.g. loop variables).
        modified_vars: Out-of-scope variables changed by the body. They are
            passed in AND returned, so the caller sees the new values. When
            non-empty, a ``return`` of them (the variable itself for one, a
            list for several) is appended to the function body.

    Returns:
        A one-element list: an assignment ``v1, v2 = fn(...)`` when
        ``modified_vars`` is non-empty, otherwise an expression statement
        ``fn(...)``.
    """
    fn_name = self._ctx.next_implicit_fn_name(prefix)
    outputs: List[str] = list(modified_vars or [])

    # Modified variables must also be passed in; keep first-seen order.
    fn_inputs = list(inputs or [])
    for name in outputs:
        if name not in fn_inputs:
            fn_inputs.append(name)

    statements = list(body)
    if outputs:
        if len(outputs) == 1:
            returned = ir.Expr(
                variable=ir.Variable(name=outputs[0]),
                span=_make_span(node),
            )
        else:
            # Tuples are represented as lists in the IR.
            elements = [ir.Expr(variable=ir.Variable(name=name)) for name in outputs]
            returned = ir.Expr(
                list=ir.ListExpr(elements=elements),
                span=_make_span(node),
            )
        return_stmt = ir.Statement(span=_make_span(node))
        return_stmt.return_stmt.CopyFrom(ir.ReturnStmt(value=returned))
        statements.append(return_stmt)

    self._ctx.implicit_functions.append(
        ir.FunctionDef(
            name=fn_name,
            io=ir.IoDecl(inputs=fn_inputs, outputs=outputs),
            body=ir.Block(statements=statements),
            span=_make_span(node),
        )
    )

    call_kwargs = [
        ir.Kwarg(name=name, value=ir.Expr(variable=ir.Variable(name=name)))
        for name in fn_inputs
    ]
    call_expr = ir.Expr(
        function_call=ir.FunctionCall(name=fn_name, kwargs=call_kwargs),
        span=_make_span(node),
    )

    call_stmt = ir.Statement(span=_make_span(node))
    if not outputs:
        call_stmt.expr_stmt.CopyFrom(ir.ExprStmt(expr=call_expr))
        return [call_stmt]

    # Assign the returned values back to the modified variables.
    assignment = ir.Assignment(value=call_expr)
    assignment.targets.extend(outputs)
    call_stmt.assignment.CopyFrom(assignment)
    return [call_stmt]
def _visit_return(self, node: ast.Return) -> List[ir.Statement]:
    """Convert a ``return`` statement to IR.

    Return values must be directly expressible in the IR (variables,
    literals, ...). Two forms are normalized first:

    * ``return await action()`` becomes::

          _return_tmp = await action()
          return _return_tmp

    * ``return MyModel(...)`` for a known Pydantic model or dataclass is
      converted to a dict expression with
      ``_convert_model_constructor_to_dict``, matching how such constructors
      are allowed by ``_check_constructor_in_return``.

    Raises:
        UnsupportedPatternError: for constructor calls of unknown classes
            (raised by ``_check_constructor_in_return``, recommending an
            @action instead), or when the return value cannot be converted to
            IR. An unconvertible value is reported rather than being silently
            replaced by a bare ``return``.
    """
    if node.value is None:
        # Bare ``return``.
        stmt = ir.Statement(span=_make_span(node))
        stmt.return_stmt.CopyFrom(ir.ReturnStmt())
        return [stmt]

    # Reject e.g. ``return SomeClass(...)`` for classes that are not models.
    self._check_constructor_in_return(node.value)

    action_call = self._extract_action_call(node.value)
    if action_call:
        # Return statements may not contain calls: bind to a temporary first.
        tmp_var = "_return_tmp"

        assign_stmt = ir.Statement(span=_make_span(node))
        value = ir.Expr(action_call=action_call, span=_make_span(node))
        assign_stmt.assignment.CopyFrom(ir.Assignment(targets=[tmp_var], value=value))

        return_stmt = ir.Statement(span=_make_span(node))
        var_expr = ir.Expr(variable=ir.Variable(name=tmp_var), span=_make_span(node))
        return_stmt.return_stmt.CopyFrom(ir.ReturnStmt(value=var_expr))

        return [assign_stmt, return_stmt]

    model_name = self._is_model_constructor(node.value)
    if model_name:
        expr = self._convert_model_constructor_to_dict(node.value, model_name)
    else:
        expr = _expr_to_ir(node.value)

    if expr is None:
        raise UnsupportedPatternError(
            "Return value could not be converted to IR",
            "Return a variable, a literal, or the result of an awaited @action call.",
            line=getattr(node, "lineno", None),
            col=getattr(node, "col_offset", None),
        )

    stmt = ir.Statement(span=_make_span(node))
    stmt.return_stmt.CopyFrom(ir.ReturnStmt(value=expr))
    return [stmt]
def _visit_aug_assign(self, node: ast.AugAssign) -> Optional[ir.Statement]:
    """Convert an augmented assignment (``+=``, ``-=``, ...) to IR.

    ``target op= value`` is lowered to the plain assignment
    ``target = target op value``.

    Returns:
        The assignment statement, or None when either operand cannot be
        converted to IR or the operator has no IR equivalent.

    Raises:
        UnsupportedPatternError: if the target is an attribute or subscript
            (e.g. ``self.total += x`` or ``counts[key] += 1``). IR assignment
            targets are plain variable names, so lowering such a statement
            would produce an assignment with no targets at all.
    """
    target = node.target
    if not isinstance(target, ast.Name):
        raise UnsupportedPatternError(
            "Augmented assignment to an attribute or subscript is not supported",
            "Update a plain local variable instead, e.g. 'total = total + value'.",
            line=getattr(node, "lineno", None),
            col=getattr(node, "col_offset", None),
        )

    left = _expr_to_ir(target)
    right = _expr_to_ir(node.value)
    if not (left and right):
        return None

    op = _bin_op_to_ir(node.op)
    if not op:
        return None

    value = ir.Expr(binary_op=ir.BinaryOp(left=left, op=op, right=right))
    stmt = ir.Statement(span=_make_span(node))
    stmt.assignment.CopyFrom(ir.Assignment(targets=[target.id], value=value))
    return stmt
def _check_constructor_in_return(self, node: ast.expr) -> None:
    """Reject ``return SomeClass(...)`` where ``SomeClass`` looks like a class.

    The workflow IR cannot serialize arbitrary object instantiation, so such
    objects should be built inside an @action. Awaited expressions (action
    calls) and any non-call expression pass untouched; whether a call is a
    constructor is decided by ``_get_constructor_name`` and
    ``_looks_like_constructor``.

    Raises:
        UnsupportedPatternError: if ``node`` is a call that looks like a
            constructor.
    """
    # An ``await`` is never an ast.Call, so action calls fall out here too.
    if not isinstance(node, ast.Call):
        return

    func_name = self._get_constructor_name(node.func)
    if not func_name or not self._looks_like_constructor(func_name, node):
        return

    raise UnsupportedPatternError(
        f"Returning constructor call '{func_name}(...)' is not supported",
        RECOMMENDATIONS["constructor_return"],
        line=getattr(node, "lineno", None),
        col=getattr(node, "col_offset", None),
    )
def _check_constructor_in_assignment(self, node: ast.expr) -> None:
    """Reject ``x = SomeClass(...)`` where ``SomeClass`` looks like a class.

    The workflow IR cannot serialize arbitrary object instantiation, so such
    objects should be built inside an @action. Awaited expressions (action
    calls) and any non-call expression pass untouched; whether a call is a
    constructor is decided by ``_get_constructor_name`` and
    ``_looks_like_constructor``.

    Raises:
        UnsupportedPatternError: if ``node`` is a call that looks like a
            constructor.
    """
    # An ``await`` is never an ast.Call, so action calls fall out here too.
    if not isinstance(node, ast.Call):
        return

    func_name = self._get_constructor_name(node.func)
    if not func_name or not self._looks_like_constructor(func_name, node):
        return

    raise UnsupportedPatternError(
        f"Assigning constructor call '{func_name}(...)' is not supported",
        RECOMMENDATIONS["constructor_assignment"],
        line=getattr(node, "lineno", None),
        col=getattr(node, "col_offset", None),
    )
def _get_constructor_name(self, func: ast.expr) -> Optional[str]:
    """Return the bare name of a call's callee expression.

    ``Foo`` yields ``"Foo"`` and ``pkg.mod.Bar`` yields ``"Bar"``; any other
    callee expression (a call result, a subscript, ...) yields None.
    """
    if isinstance(func, ast.Attribute):
        return func.attr
    return func.id if isinstance(func, ast.Name) else None
def _looks_like_constructor(self, func_name: str, call: ast.Call) -> bool:
    """Heuristically decide whether calling ``func_name`` instantiates a class.

    A name counts as a constructor when it starts with an uppercase letter
    (the PEP 8 class-naming convention) and is none of:

    * a known action (``self._action_defs``),
    * a known Pydantic model or dataclass (``self._model_defs``) - those are
      allowed because they are converted to dict expressions,
    * one of the capitalized builtin names ``True``, ``False``, ``None``,
      ``Ellipsis``.

    Without full type information this cannot be exact. ``call`` is
    currently unused.
    """
    if not func_name or not func_name[0].isupper():
        return False

    capitalized_builtins = ("True", "False", "None", "Ellipsis")
    return not (
        func_name in self._action_defs
        or func_name in self._model_defs
        or func_name in capitalized_builtins
    )
def _is_model_constructor(self, node: ast.expr) -> Optional[str]:
    """Return the model name if ``node`` calls a known Pydantic model or dataclass.

    The callee's bare name (from ``_get_constructor_name``) must be a key of
    ``self._model_defs``; otherwise, or when ``node`` is not a call, the
    result is None.
    """
    if isinstance(node, ast.Call):
        name = self._get_constructor_name(node.func)
        if name and name in self._model_defs:
            return name
    return None
def _convert_model_constructor_to_dict(self, node: ast.Call, model_name: str) -> ir.Expr:
    """Convert a Pydantic model or dataclass constructor call to a dict expression.

    For example::

        MyModel(field1=value1, field2=value2)

    becomes::

        {"field1": value1, "field2": value2}

    Keyword arguments are emitted first, then positional arguments (mapped
    onto ``model_def.fields`` in order), then default values from the model
    definition for fields that were not supplied. Defaults that
    ``_constant_to_literal`` cannot convert are omitted.

    Raises:
        UnsupportedPatternError: on ``**kwargs`` expansion, on more positional
            arguments than fields, when a field is supplied both positionally
            and by keyword, or when an argument value cannot be converted.
    """
    model_def = self._model_defs[model_name]
    line = getattr(node, "lineno", None)
    col = getattr(node, "col_offset", None)

    def _unsupported(message: str, recommendation: str) -> UnsupportedPatternError:
        # All errors here point at the constructor call itself.
        return UnsupportedPatternError(message, recommendation, line=line, col=col)

    def _string_key(field_name: str) -> ir.Expr:
        key_expr = ir.Expr()
        key_literal = ir.Literal()
        key_literal.string_value = field_name
        key_expr.literal.CopyFrom(key_literal)
        return key_expr

    entries: List[ir.DictEntry] = []
    provided_fields: Set[str] = set()

    # Explicit keyword arguments.
    for kw in node.keywords:
        if kw.arg is None:
            raise _unsupported(
                f"Model constructor '{model_name}' with **kwargs is not supported",
                "Use explicit keyword arguments instead of **kwargs.",
            )
        provided_fields.add(kw.arg)

        value_expr = _expr_to_ir(kw.value)
        if value_expr is None:
            raise _unsupported(
                f"Cannot convert value for field '{kw.arg}' in '{model_name}'",
                "Use simpler expressions (literals, variables, dicts, lists).",
            )
        entries.append(ir.DictEntry(key=_string_key(kw.arg), value=value_expr))

    # Positional arguments (dataclasses support this) map to fields in order.
    field_names = list(model_def.fields.keys())
    for index, arg in enumerate(node.args):
        if index >= len(field_names):
            raise _unsupported(
                f"Too many positional arguments for '{model_name}'",
                "Use keyword arguments for clarity.",
            )
        field_name = field_names[index]
        if field_name in provided_fields:
            # A real dataclass __init__ raises TypeError ("got multiple values")
            # here; without this check the dict would get a duplicate key.
            raise _unsupported(
                f"Field '{field_name}' in '{model_name}' is given both positionally and by keyword",
                "Pass each field only once.",
            )
        provided_fields.add(field_name)

        value_expr = _expr_to_ir(arg)
        if value_expr is None:
            raise _unsupported(
                f"Cannot convert positional argument for field '{field_name}' in '{model_name}'",
                "Use simpler expressions (literals, variables, dicts, lists).",
            )
        entries.append(ir.DictEntry(key=_string_key(field_name), value=value_expr))

    # Defaults for fields that were not supplied.
    for field_name, field_def in model_def.fields.items():
        if field_name in provided_fields or not field_def.has_default:
            continue

        default_literal = _constant_to_literal(field_def.default_value)
        if default_literal is None:
            # Not representable as an IR literal (presumably a complex default
            # such as a factory) - omit it.
            continue

        value_expr = ir.Expr()
        value_expr.literal.CopyFrom(default_literal)
        entries.append(ir.DictEntry(key=_string_key(field_name), value=value_expr))

    result = ir.Expr(span=_make_span(node))
    result.dict.CopyFrom(ir.DictExpr(entries=entries))
    return result
def _check_non_action_await(self, node: ast.Await) -> None:
    """Reject ``await`` of a same-module function that is not an @action.

    Only names listed in ``self._module_functions`` (functions defined in the
    workflow's own module, when that attribute exists) can be checked.
    Awaited callables imported from elsewhere pass through and are left to a
    runtime check. ``self.run_action(...)``, ``asyncio.sleep(...)`` and
    ``asyncio.gather(...)`` are skipped because they are handled elsewhere.

    Raises:
        UnsupportedPatternError: if the awaited function is a known module
            function that is not in ``self._action_defs``.
    """
    call = node.value
    if not isinstance(call, ast.Call):
        return
    if (
        self._is_run_action_call(call)
        or self._is_asyncio_sleep_call(call)
        or self._is_asyncio_gather_call(call)
    ):
        return

    callee = call.func
    if isinstance(callee, ast.Name):
        func_name = callee.id
    elif isinstance(callee, ast.Attribute):
        func_name = callee.attr
    else:
        return
    if not func_name:
        return

    # Imported non-actions cannot be detected without full type information.
    module_functions = getattr(self, "_module_functions", set())
    if func_name not in module_functions or func_name in self._action_defs:
        return

    raise UnsupportedPatternError(
        f"Awaiting non-action function '{func_name}()' is not supported",
        RECOMMENDATIONS["non_action_call"],
        line=getattr(node, "lineno", None),
        col=getattr(node, "col_offset", None),
    )
def _check_sync_function_call(self, node: ast.Call) -> None:
    """Reject direct calls to common Python built-ins in workflow code.

    Calls such as ``len(...)``, ``str(...)`` or ``sorted(...)`` are not
    supported in workflow code. For a method call (``obj.name(...)``) only
    the method name is compared, so e.g. ``text.format(...)`` is rejected too.

    Raises:
        UnsupportedPatternError: if the called name is one of the listed
            built-ins.
    """
    callee = node.func
    if isinstance(callee, ast.Name):
        func_name = callee.id
    elif isinstance(callee, ast.Attribute):
        # Method calls on objects: only the method name is checked.
        func_name = callee.attr
    else:
        func_name = None
    if not func_name:
        return

    # Built-ins that users commonly reach for in workflow code.
    common_builtins = {
        "len", "str", "int", "float", "bool",
        "list", "dict", "set", "tuple",
        "sum", "min", "max", "sorted", "reversed",
        "enumerate", "zip", "map", "filter", "range",
        "abs", "round", "print", "type",
        "isinstance", "hasattr", "getattr", "setattr",
        "open", "format",
    }
    if func_name not in common_builtins:
        return

    raise UnsupportedPatternError(
        f"Calling built-in function '{func_name}()' directly is not supported",
        RECOMMENDATIONS["builtin_call"],
        line=getattr(node, "lineno", None),
        col=getattr(node, "col_offset", None),
    )
def _extract_action_call(self, node: ast.expr) -> Optional[ir.ActionCall]:
    """Return the action call awaited by ``node``, or None.

    Recognized forms, in order:

    1. ``await self.run_action(action(...), retry=..., timeout=...)`` - the
       wrapped action is extracted and the run_action keywords are attached
       as policies;
    2. ``await asyncio.sleep(...)`` - lowered to the built-in sleep action;
    3. ``await action(...)`` - via ``_extract_action_call_from_call``.

    Anything else that awaits a call is passed to ``_check_non_action_await``,
    which raises UnsupportedPatternError for known non-action functions.
    Non-await expressions and awaits of non-calls return None.
    """
    if not isinstance(node, ast.Await) or not isinstance(node.value, ast.Call):
        return None
    awaited = node.value

    if self._is_run_action_call(awaited) and awaited.args:
        wrapped = self._extract_action_call_from_awaitable(awaited.args[0])
        if wrapped:
            self._extract_policies_from_run_action(awaited, wrapped)
            return wrapped

    if self._is_asyncio_sleep_call(awaited):
        return self._convert_asyncio_sleep_to_action(awaited)

    direct = self._extract_action_call_from_call(awaited)
    if direct:
        return direct

    # Awaiting something that is not an action.
    self._check_non_action_await(node)
    return None
def _is_run_action_call(self, node: ast.Call) -> bool:
    """Return True when the callee is an attribute named ``run_action``.

    Matches ``self.run_action(...)``; the receiver object is not inspected.
    """
    callee = node.func
    return isinstance(callee, ast.Attribute) and callee.attr == "run_action"
def _extract_policies_from_run_action(
    self, run_action_call: ast.Call, action_call: ir.ActionCall
) -> None:
    """Attach retry/timeout policies from ``run_action`` keywords to ``action_call``.

    Recognized patterns:

    - self.run_action(action(), retry=RetryPolicy(attempts=3))
    - self.run_action(action(), timeout=timedelta(seconds=30))
    - self.run_action(action(), timeout=60)

    Each ``retry``/``timeout`` keyword whose value parses successfully
    appends one ``ir.PolicyBracket`` to ``action_call.policies``, in keyword
    order. Other keywords, and values the parsers reject, are ignored.
    """
    # Keyword name -> parser; the keyword name is also the PolicyBracket field.
    parsers = {
        "retry": self._parse_retry_policy,
        "timeout": self._parse_timeout_policy,
    }
    for kw in run_action_call.keywords:
        parse = parsers.get(kw.arg)
        if parse is None:
            continue
        parsed = parse(kw.value)
        if not parsed:
            continue
        bracket = ir.PolicyBracket()
        getattr(bracket, kw.arg).CopyFrom(parsed)
        action_call.policies.append(bracket)
def _parse_retry_policy(self, node: ast.expr) -> Optional[ir.RetryPolicy]:
    """Parse a ``RetryPolicy(...)`` call into IR.

    Supports:

    - RetryPolicy(attempts=3)
    - RetryPolicy(attempts=3, exception_types=["ValueError"])
    - RetryPolicy(attempts=3, exception_types=[ValueError, errors.TimeoutError])
    - RetryPolicy(attempts=3, exception_types=("ValueError",))
    - RetryPolicy(attempts=3, backoff_seconds=5)

    Exception types may be string literals or (possibly dotted) class
    references. A class reference is recorded by its bare class name, like the
    names recorded for ``except`` handler types; previously such references
    were silently dropped. Keyword values that are not literals
    (e.g. ``attempts=n``) are ignored.

    Returns:
        The parsed policy, or None when ``node`` is not a ``RetryPolicy(...)``
        call.
    """
    if not isinstance(node, ast.Call):
        return None

    callee = node.func
    if isinstance(callee, ast.Name):
        func_name = callee.id
    elif isinstance(callee, ast.Attribute):
        func_name = callee.attr
    else:
        func_name = None
    if func_name != "RetryPolicy":
        return None

    policy = ir.RetryPolicy()

    for kw in node.keywords:
        value = kw.value
        if kw.arg == "attempts" and isinstance(value, ast.Constant):
            policy.max_retries = value.value
        elif kw.arg == "exception_types" and isinstance(value, (ast.List, ast.Tuple)):
            for elt in value.elts:
                if isinstance(elt, ast.Constant) and isinstance(elt.value, str):
                    policy.exception_types.append(elt.value)
                elif isinstance(elt, ast.Name):
                    policy.exception_types.append(elt.id)
                elif isinstance(elt, ast.Attribute):
                    # e.g. httpx.TimeoutException -> "TimeoutException"
                    policy.exception_types.append(elt.attr)
        elif kw.arg == "backoff_seconds" and isinstance(value, ast.Constant):
            policy.backoff.seconds = int(value.value)

    return policy
def _parse_timeout_policy(self, node: ast.expr) -> Optional[ir.TimeoutPolicy]:
    """Parse a timeout value into IR.

    Supports:

    - timeout=60 (int seconds)
    - timeout=30.5 (float seconds, truncated to 30)
    - timeout=timedelta(seconds=30)
    - timeout=timedelta(minutes=1.5) (90 seconds)
    - timeout=timedelta(1) (positional arguments follow timedelta's own order)

    ``timedelta(...)`` calls whose arguments are all numeric literals are
    evaluated with :class:`datetime.timedelta` itself. Every unit it accepts
    (weeks, days, hours, minutes, seconds, milliseconds, microseconds) and
    fractional values are therefore handled exactly as Python would. The
    total is truncated to whole seconds.

    Returns:
        The parsed policy, or None when the value is neither a numeric
        literal nor such a ``timedelta(...)`` call. A non-literal component,
        ``**kwargs`` expansion or invalid argument also yields None, rather
        than a silently shortened timeout.
    """
    # Direct numeric value (seconds).
    if isinstance(node, ast.Constant) and isinstance(node.value, (int, float)):
        policy = ir.TimeoutPolicy()
        policy.timeout.seconds = int(node.value)
        return policy

    if not isinstance(node, ast.Call):
        return None

    callee = node.func
    if isinstance(callee, ast.Name):
        func_name = callee.id
    elif isinstance(callee, ast.Attribute):
        func_name = callee.attr
    else:
        func_name = None
    if func_name != "timedelta":
        return None

    def _literal_number(expr: ast.expr) -> Optional[Union[int, float]]:
        if isinstance(expr, ast.Constant) and isinstance(expr.value, (int, float)):
            return expr.value
        return None

    args = [_literal_number(arg) for arg in node.args]
    # ``**kwargs`` expansion shows up as a keyword whose ``arg`` is None.
    kwargs = {kw.arg: _literal_number(kw.value) for kw in node.keywords}
    if None in args or None in kwargs or None in kwargs.values():
        # Not evaluable at compile time.
        return None

    from datetime import timedelta as _timedelta

    try:
        delta = _timedelta(*args, **kwargs)
    except (TypeError, ValueError, OverflowError):
        # Unknown keyword, too many positional arguments, or out of range.
        return None

    policy = ir.TimeoutPolicy()
    policy.timeout.seconds = int(delta.total_seconds())
    return policy
def _is_asyncio_sleep_call(self, node: ast.Call) -> bool:
    """Return True if ``node`` calls ``asyncio.sleep``.

    Recognized spellings:

    - import asyncio; asyncio.sleep(1)
    - from asyncio import sleep; sleep(1)
    - from asyncio import sleep as s; s(1)

    Bare names are resolved through ``self._imported_names``.
    """
    callee = node.func
    if isinstance(callee, ast.Attribute):
        # asyncio.sleep(...)
        return (
            callee.attr == "sleep"
            and isinstance(callee.value, ast.Name)
            and callee.value.id == "asyncio"
        )
    if isinstance(callee, ast.Name):
        # sleep(...) / alias(...) imported from asyncio
        imported_names = self._imported_names
        if callee.id in imported_names:
            origin = imported_names[callee.id]
            return origin.module == "asyncio" and origin.original_name == "sleep"
    return False
def _convert_asyncio_sleep_to_action(self, node: ast.Call) -> ir.ActionCall:
    """Lower ``asyncio.sleep(duration)`` to the built-in ``sleep`` action.

    The duration is taken from the first positional argument or, when there
    is none, from the first ``seconds``/``delay``/``duration`` keyword. It is
    passed as the ``duration`` kwarg. If that expression cannot be converted
    to IR, the action is emitted without a duration. The scheduler handles
    this action as a durable sleep, stored in the DB with a future
    ``scheduled_at`` time.
    """
    action_call = ir.ActionCall(action_name="sleep")

    if node.args:
        # asyncio.sleep(1)
        duration_node = node.args[0]
    else:
        # asyncio.sleep(seconds=1) and friends (less common)
        duration_node = next(
            (kw.value for kw in node.keywords if kw.arg in ("seconds", "delay", "duration")),
            None,
        )

    if duration_node is not None:
        duration = _expr_to_ir(duration_node)
        if duration:
            action_call.kwargs.append(ir.Kwarg(name="duration", value=duration))

    return action_call
def _is_asyncio_gather_call(self, node: ast.Call) -> bool:
    """Return True if ``node`` calls ``asyncio.gather``.

    Recognized spellings:

    - import asyncio; asyncio.gather(a(), b())
    - from asyncio import gather; gather(a(), b())
    - from asyncio import gather as g; g(a(), b())

    Bare names are resolved through ``self._imported_names``.
    """
    callee = node.func
    if isinstance(callee, ast.Attribute):
        # asyncio.gather(...)
        return (
            callee.attr == "gather"
            and isinstance(callee.value, ast.Name)
            and callee.value.id == "asyncio"
        )
    if isinstance(callee, ast.Name):
        # gather(...) / alias(...) imported from asyncio
        imported_names = self._imported_names
        if callee.id in imported_names:
            origin = imported_names[callee.id]
            return origin.module == "asyncio" and origin.original_name == "gather"
    return False
def _convert_asyncio_gather(
    self, node: ast.Call
) -> Optional[Union[ir.ParallelExpr, ir.SpreadExpr]]:
    """Convert ``asyncio.gather(...)`` to a ParallelExpr or SpreadExpr.

    Two shapes are handled:

    1. Static gather, ``asyncio.gather(a(), b(), c())`` -> ParallelExpr with
       one call per argument (via ``_convert_gather_arg_to_call``).
    2. Spread gather, ``asyncio.gather(*[action(x) for x in items])`` ->
       SpreadExpr (via ``_convert_listcomp_to_spread_expr``).

    NOTE(review): in the static form, arguments for which
    ``_convert_gather_arg_to_call`` returns a falsy value are left out of the
    ParallelExpr - confirm that helper raises for unsupported arguments,
    otherwise results could be misaligned with the caller's targets.

    Args:
        node: The ``asyncio.gather()`` call node.

    Returns:
        A ParallelExpr or SpreadExpr, or None when no argument converted to a
        call.

    Raises:
        UnsupportedPatternError: when a single starred argument spreads
            anything other than a list comprehension.
    """
    args = node.args

    if len(args) == 1 and isinstance(args[0], ast.Starred):
        spread_source = args[0].value
        # Only list comprehensions are supported for spread.
        if isinstance(spread_source, ast.ListComp):
            return self._convert_listcomp_to_spread_expr(spread_source)

        if isinstance(spread_source, ast.Name):
            message = (
                f"Spreading variable '{spread_source.id}' in asyncio.gather() is not supported"
            )
        else:
            message = (
                "Spreading non-list-comprehension expressions in asyncio.gather() is not supported"
            )
        raise UnsupportedPatternError(
            message,
            RECOMMENDATIONS["gather_variable_spread"],
            line=getattr(node, "lineno", None),
            col=getattr(node, "col_offset", None),
        )

    # Static gather: each argument should be an action call.
    parallel = ir.ParallelExpr()
    for arg in args:
        converted = self._convert_gather_arg_to_call(arg)
        if converted:
            parallel.calls.append(converted)

    return parallel if parallel.calls else None
def _convert_listcomp_to_spread_expr(self, listcomp: ast.ListComp) -> Optional[ir.SpreadExpr]:
    """Convert a list comprehension to SpreadExpr IR.

    Handles patterns like:
        [action(x=item) for item in collection]

    The comprehension must have exactly one synchronous generator with no
    ``if`` conditions, its loop target must be a plain name, and its element
    must be an @action call.

    Args:
        listcomp: The ListComp AST node

    Returns:
        A SpreadExpr binding ``loop_var`` over ``collection`` for ``action``.
        Every unsupported shape raises rather than returning None; the
        Optional return type is kept for interface compatibility.

    Raises:
        UnsupportedPatternError: If the comprehension does not match the
            supported shape. All errors are reported at the comprehension's
            source location.
    """

    def reject(message: str, recommendation: str) -> UnsupportedPatternError:
        # Every spread-pattern error points at the comprehension itself.
        return UnsupportedPatternError(
            message,
            recommendation,
            line=getattr(listcomp, "lineno", None),
            col=getattr(listcomp, "col_offset", None),
        )

    # Only support simple list comprehensions with one generator
    if len(listcomp.generators) != 1:
        raise reject(
            "Spread pattern only supports a single loop variable",
            "Use a simple list comprehension: [action(x) for x in items]",
        )

    gen = listcomp.generators[0]

    # `async for` iterates an async iterable. SpreadExpr only carries a plain
    # collection expression, so reject it explicitly. Otherwise the async
    # iterable would silently be treated as an ordinary collection.
    if gen.is_async:
        raise reject(
            "Spread pattern does not support async comprehensions",
            "Use a regular comprehension: [action(x) for x in items]",
        )

    # Check for conditions - not supported
    if gen.ifs:
        raise reject(
            "Spread pattern does not support conditions in list comprehension",
            "Remove the 'if' clause from the comprehension",
        )

    # The loop target must be a single plain name (no tuple unpacking).
    if not isinstance(gen.target, ast.Name):
        raise reject(
            "Spread pattern requires a simple loop variable",
            "Use a simple variable: [action(x) for x in items]",
        )
    loop_var = gen.target.id

    # Get the collection expression
    collection_expr = _expr_to_ir(gen.iter)
    if not collection_expr:
        raise reject(
            "Could not convert collection expression in spread pattern",
            "Ensure the collection is a simple variable or expression",
        )

    # The element must be an action call
    if not isinstance(listcomp.elt, ast.Call):
        raise reject(
            "Spread pattern requires an action call in the list comprehension",
            "Use: [action(x=item) for item in items]",
        )

    action_call = self._extract_action_call_from_call(listcomp.elt)
    if not action_call:
        raise reject(
            "Spread pattern element must be an @action call",
            "Ensure the function is decorated with @action",
        )

    # Build the SpreadExpr
    spread = ir.SpreadExpr()
    spread.collection.CopyFrom(collection_expr)
    spread.loop_var = loop_var
    spread.action.CopyFrom(action_call)
    return spread
2696
+
2697
def _convert_gather_arg_to_call(self, node: ast.expr) -> Optional[ir.Call]:
    """Turn one positional argument of asyncio.gather() into an IR Call.

    @action calls take precedence. Any other call is emitted as a plain
    function call. Non-call arguments, and calls whose callee cannot be
    resolved to a name, yield None.
    """
    if not isinstance(node, ast.Call):
        return None

    as_action = self._extract_action_call_from_call(node)
    if as_action:
        wrapped = ir.Call()
        wrapped.action.CopyFrom(as_action)
        return wrapped

    as_function = self._convert_to_function_call(node)
    if not as_function:
        return None

    wrapped = ir.Call()
    wrapped.function.CopyFrom(as_function)
    return wrapped
2720
+
2721
def _convert_to_function_call(self, node: ast.Call) -> Optional[ir.FunctionCall]:
    """Translate an AST call into an IR FunctionCall.

    Positional and named keyword arguments are converted with _expr_to_ir.
    Arguments that fail conversion are skipped, as are ``**mapping``
    keywords, which carry no name. Returns None when the callee has no
    resolvable name.
    """
    callee = self._get_func_name(node.func)
    if not callee:
        return None

    call = ir.FunctionCall(name=callee)

    for positional in node.args:
        converted = _expr_to_ir(positional)
        if converted:
            call.args.append(converted)

    named_keywords = [keyword for keyword in node.keywords if keyword.arg]
    for keyword in named_keywords:
        converted = _expr_to_ir(keyword.value)
        if converted:
            call.kwargs.append(ir.Kwarg(name=keyword.arg, value=converted))

    return call
2743
+
2744
def _get_func_name(self, node: ast.expr) -> Optional[str]:
    """Return the dotted callee name for a call's ``func`` node.

    ``foo`` gives ``"foo"`` and ``a.b.c`` gives ``"a.b.c"``. When an attribute
    chain is rooted in something other than a plain name (e.g. ``f().g``),
    only the attribute segments are kept (``"g"``). Any other node type
    gives None.
    """
    if isinstance(node, ast.Name):
        return node.id
    if not isinstance(node, ast.Attribute):
        return None

    segments: List[str] = []
    cursor: ast.expr = node
    while isinstance(cursor, ast.Attribute):
        segments.append(cursor.attr)
        cursor = cursor.value
    if isinstance(cursor, ast.Name):
        segments.append(cursor.id)

    # Segments were collected innermost-last; flip to source order.
    segments.reverse()
    return ".".join(segments)
2759
+
2760
def _expr_to_ir_with_model_coercion(self, node: ast.expr) -> Optional[ir.Expr]:
    """Convert an expression to IR, turning model constructors into dicts.

    Used for action arguments. A call that ``_is_model_constructor``
    recognizes (e.g. ``MyModel(field=value)``) is rewritten by
    ``_convert_model_constructor_to_dict`` into a dict expression. Any
    other expression goes through the standard ``_expr_to_ir`` conversion.
    """
    model_name = self._is_model_constructor(node) if isinstance(node, ast.Call) else None
    if model_name:
        return self._convert_model_constructor_to_dict(node, model_name)
    return _expr_to_ir(node)
2778
+
2779
def _extract_action_call_from_awaitable(self, node: ast.expr) -> Optional[ir.ActionCall]:
    """Return the action call behind an awaited expression, if any.

    Only direct calls are inspected; every other expression yields None.
    """
    return self._extract_action_call_from_call(node) if isinstance(node, ast.Call) else None
2784
+
2785
def _extract_action_call_from_call(self, node: ast.Call) -> Optional[ir.ActionCall]:
    """Build an IR ActionCall for a call to a registered @action.

    Positional arguments are rewritten as keyword arguments, named after the
    action's signature parameters in order. Positional arguments beyond the
    signature's parameter count are not emitted. Explicit keyword arguments
    follow, and ``**mapping`` keywords (no name) are skipped. Each value goes
    through ``_expr_to_ir_with_model_coercion``, so model/dataclass
    constructors become dict expressions. Values that fail conversion are
    skipped.

    Returns None when the callee has no name or is not in ``_action_defs``.
    """
    action_name = self._get_action_name(node.func)
    if not action_name or action_name not in self._action_defs:
        return None

    definition = self._action_defs[action_name]
    action_call = ir.ActionCall(action_name=definition.action_name)

    # The module name tells the worker where to import the action from.
    if definition.module_name:
        action_call.module_name = definition.module_name

    # zip() stops at the shorter side, so surplus positionals are never converted.
    parameter_names = list(definition.signature.parameters)
    for parameter_name, positional in zip(parameter_names, node.args):
        converted = self._expr_to_ir_with_model_coercion(positional)
        if converted:
            action_call.kwargs.append(ir.Kwarg(name=parameter_name, value=converted))

    for keyword in node.keywords:
        if not keyword.arg:
            continue
        converted = self._expr_to_ir_with_model_coercion(keyword.value)
        if converted:
            action_call.kwargs.append(ir.Kwarg(name=keyword.arg, value=converted))

    return action_call
2830
+
2831
def _get_action_name(self, func: ast.expr) -> Optional[str]:
    """Return the bare action name referenced by a call's ``func`` node.

    ``act`` gives ``"act"``. For attribute access only the final attribute
    is used (``mod.act`` gives ``"act"``). Anything else gives None.
    """
    if isinstance(func, ast.Attribute):
        return func.attr
    return func.id if isinstance(func, ast.Name) else None
2838
+
2839
def _get_assign_target(self, targets: List[ast.expr]) -> Optional[str]:
    """Return the variable name of the first assignment target.

    Only a plain name counts. An empty target list, or a first target that
    is a tuple, subscript, attribute, etc., gives None.
    """
    first = targets[0] if targets else None
    return first.id if isinstance(first, ast.Name) else None
2844
+
2845
def _get_assign_targets(self, targets: List[ast.expr]) -> List[str]:
    """Collect the names written by an assignment's targets.

    - Plain names are returned as-is.
    - Subscripts are rendered via ``_format_subscript_target`` (e.g. ``"d[k]"``)
      when it succeeds.
    - For a tuple target, only its direct plain-name elements are included.
      Nested tuples and starred elements are ignored.
    - Other target kinds (attributes, lists) contribute nothing.
    """
    names: List[str] = []
    for target in targets:
        if isinstance(target, ast.Name):
            names.append(target.id)
        elif isinstance(target, ast.Subscript):
            rendered = _format_subscript_target(target)
            if rendered:
                names.append(rendered)
        elif isinstance(target, ast.Tuple):
            names.extend(element.id for element in target.elts if isinstance(element, ast.Name))
    return names
2860
+
2861
+
2862
def _make_span(node: ast.AST) -> ir.Span:
    """Build an IR Span covering *node*'s source range.

    Missing position attributes default to 0. End positions that are present
    but ``None`` are also normalized to 0.
    """
    first_line = getattr(node, "lineno", 0)
    first_col = getattr(node, "col_offset", 0)
    last_line = getattr(node, "end_lineno", 0) or 0
    last_col = getattr(node, "end_col_offset", 0) or 0
    return ir.Span(
        start_line=first_line,
        start_col=first_col,
        end_line=last_line,
        end_col=last_col,
    )
2870
+
2871
+
2872
def _expr_to_ir(expr: ast.AST) -> Optional[ir.Expr]:
    """Convert a Python AST expression to an IR ``Expr``.

    Handled node types: names, constants, arithmetic/unary operators,
    comparisons, boolean operators, list and tuple displays (tuples become
    lists), dict displays, subscripts, attribute access, and calls whose
    callee resolves to a name.

    Chained comparisons such as ``0 < x < 10`` are lowered to
    ``(0 < x) and (x < 10)``, matching Python's semantics, so no link of
    the chain is dropped.

    Args:
        expr: The AST node to convert.

    Returns:
        The converted ``Expr``, or None when the node (or a sub-expression it
        requires) cannot be represented.

    Raises:
        UnsupportedPatternError: Via ``_check_unsupported_expression`` for
            node kinds with a dedicated rejection (f-strings, lambdas,
            comprehensions, walrus, yield), including in sub-expressions.
    """
    result = ir.Expr(span=_make_span(expr))

    if isinstance(expr, ast.Name):
        result.variable.CopyFrom(ir.Variable(name=expr.id))
        return result

    if isinstance(expr, ast.Constant):
        literal = _constant_to_literal(expr.value)
        if literal:
            result.literal.CopyFrom(literal)
            return result

    if isinstance(expr, ast.BinOp):
        left = _expr_to_ir(expr.left)
        right = _expr_to_ir(expr.right)
        op = _bin_op_to_ir(expr.op)
        if left and right and op:
            result.binary_op.CopyFrom(ir.BinaryOp(left=left, op=op, right=right))
            return result

    if isinstance(expr, ast.UnaryOp):
        operand = _expr_to_ir(expr.operand)
        op = _unary_op_to_ir(expr.op)
        if operand and op:
            result.unary_op.CopyFrom(ir.UnaryOp(op=op, operand=operand))
            return result

    if isinstance(expr, ast.Compare):
        left = _expr_to_ir(expr.left)
        if not left:
            return None
        # Python evaluates `a < b < c` as `(a < b) and (b < c)`. Build one
        # BinaryOp per link and AND the links together. Converting only the
        # first link would silently drop the rest of the condition.
        # NOTE(review): a middle operand is copied into both adjacent links.
        # That matches Python only when the operand has no side effects.
        links: List[ir.Expr] = []
        link_left = left
        for cmp_op, comparator in zip(expr.ops, expr.comparators, strict=False):
            op = _cmp_op_to_ir(cmp_op)
            right = _expr_to_ir(comparator)
            if not (op and right):
                # Any unconvertible link makes the whole comparison unsupported.
                links = []
                break
            link = ir.Expr(span=_make_span(expr))
            link.binary_op.CopyFrom(ir.BinaryOp(left=link_left, op=op, right=right))
            links.append(link)
            link_left = right
        if links:
            combined = links[0]
            for link in links[1:]:
                conjunction = ir.Expr(span=_make_span(expr))
                conjunction.binary_op.CopyFrom(
                    ir.BinaryOp(left=combined, op=ir.BinaryOperator.BINARY_OP_AND, right=link)
                )
                combined = conjunction
            return combined

    if isinstance(expr, ast.BoolOp):
        values = [_expr_to_ir(v) for v in expr.values]
        if all(v for v in values):
            op = _bool_op_to_ir(expr.op)
            if op and len(values) >= 2:
                # Chain boolean ops: a and b and c -> (a and b) and c.
                # Synthesized nodes carry the span of the whole BoolOp.
                result_expr = values[0]
                for v in values[1:]:
                    new_result = ir.Expr(span=_make_span(expr))
                    new_result.binary_op.CopyFrom(ir.BinaryOp(left=result_expr, op=op, right=v))
                    result_expr = new_result
                return result_expr

    if isinstance(expr, ast.List):
        elements = [_expr_to_ir(e) for e in expr.elts]
        if all(e for e in elements):
            list_expr = ir.ListExpr(elements=[e for e in elements if e])
            result.list.CopyFrom(list_expr)
            return result

    if isinstance(expr, ast.Dict):
        # NOTE(review): `**spread` entries (key is None) and entries whose key
        # or value fails conversion are dropped silently here.
        entries: List[ir.DictEntry] = []
        for k, v in zip(expr.keys, expr.values, strict=False):
            if k:
                key_expr = _expr_to_ir(k)
                value_expr = _expr_to_ir(v)
                if key_expr and value_expr:
                    entries.append(ir.DictEntry(key=key_expr, value=value_expr))
        result.dict.CopyFrom(ir.DictExpr(entries=entries))
        return result

    if isinstance(expr, ast.Subscript):
        obj = _expr_to_ir(expr.value)
        index = _expr_to_ir(expr.slice) if isinstance(expr.slice, ast.AST) else None
        if obj and index:
            result.index.CopyFrom(ir.IndexAccess(object=obj, index=index))
            return result

    if isinstance(expr, ast.Attribute):
        obj = _expr_to_ir(expr.value)
        if obj:
            result.dot.CopyFrom(ir.DotAccess(object=obj, attribute=expr.attr))
            return result

    if isinstance(expr, ast.Call):
        # Function call; arguments that fail conversion are skipped.
        func_name = _get_func_name(expr.func)
        if func_name:
            args = [_expr_to_ir(a) for a in expr.args]
            kwargs: List[ir.Kwarg] = []
            for kw in expr.keywords:
                if kw.arg:
                    kw_expr = _expr_to_ir(kw.value)
                    if kw_expr:
                        kwargs.append(ir.Kwarg(name=kw.arg, value=kw_expr))
            func_call = ir.FunctionCall(
                name=func_name,
                args=[a for a in args if a],
                kwargs=kwargs,
            )
            result.function_call.CopyFrom(func_call)
            return result

    if isinstance(expr, ast.Tuple):
        # Handle tuple as list for now
        elements = [_expr_to_ir(e) for e in expr.elts]
        if all(e for e in elements):
            list_expr = ir.ListExpr(elements=[e for e in elements if e])
            result.list.CopyFrom(list_expr)
            return result

    # Raise a descriptive error for known-unsupported expression kinds.
    _check_unsupported_expression(expr)

    return None
2989
+
2990
+
2991
def _check_unsupported_expression(expr: ast.AST) -> None:
    """Raise a descriptive UnsupportedPatternError for known-unsupported nodes.

    Each rejected node kind maps to a message and a key into RECOMMENDATIONS.
    Nodes that match none of the entries pass through silently (returns None).
    """
    rejected_kinds = (
        (ast.JoinedStr, "F-strings are not supported", "fstring"),
        (ast.Lambda, "Lambda expressions are not supported", "lambda"),
        (ast.ListComp, "List comprehensions are not supported in this context", "list_comprehension"),
        (ast.DictComp, "Dict comprehensions are not supported in this context", "dict_comprehension"),
        (ast.SetComp, "Set comprehensions are not supported", "set_comprehension"),
        (ast.GeneratorExp, "Generator expressions are not supported", "generator"),
        (ast.NamedExpr, "The walrus operator (:=) is not supported", "walrus"),
        ((ast.Yield, ast.YieldFrom), "Yield expressions are not supported", "yield_statement"),
    )
    for node_types, message, recommendation_key in rejected_kinds:
        if isinstance(expr, node_types):
            raise UnsupportedPatternError(
                message,
                RECOMMENDATIONS[recommendation_key],
                line=getattr(expr, "lineno", None),
                col=getattr(expr, "col_offset", None),
            )
3052
+
3053
+
3054
def _format_subscript_target(target: ast.Subscript) -> Optional[str]:
    """Render ``name[index]`` for a subscript target rooted in a plain name.

    The index is rendered with ``ast.unparse``. Returns None if the
    subscripted object is not a plain name or the index cannot be unparsed.
    """
    container = target.value
    if not isinstance(container, ast.Name):
        return None
    try:
        rendered_index = ast.unparse(target.slice)
    except Exception:
        return None
    return container.id + "[" + rendered_index + "]"
3066
+
3067
+
3068
+ def _constant_to_literal(value: Any) -> Optional[ir.Literal]:
3069
+ """Convert a Python constant to IR Literal."""
3070
+ literal = ir.Literal()
3071
+ if value is None:
3072
+ literal.is_none = True
3073
+ elif isinstance(value, bool):
3074
+ literal.bool_value = value
3075
+ elif isinstance(value, int):
3076
+ literal.int_value = value
3077
+ elif isinstance(value, float):
3078
+ literal.float_value = value
3079
+ elif isinstance(value, str):
3080
+ literal.string_value = value
3081
+ else:
3082
+ return None
3083
+ return literal
3084
+
3085
+
3086
def _bin_op_to_ir(op: ast.operator) -> Optional[ir.BinaryOperator]:
    """Map an ``ast`` arithmetic operator to its IR BinaryOperator.

    Only + - * / // % are mapped. Others (e.g. ``**``, bitwise operators)
    give None.
    """
    codes = ir.BinaryOperator
    return {
        ast.Add: codes.BINARY_OP_ADD,
        ast.Sub: codes.BINARY_OP_SUB,
        ast.Mult: codes.BINARY_OP_MUL,
        ast.Div: codes.BINARY_OP_DIV,
        ast.FloorDiv: codes.BINARY_OP_FLOOR_DIV,
        ast.Mod: codes.BINARY_OP_MOD,
    }.get(type(op))
3097
+
3098
+
3099
def _unary_op_to_ir(op: ast.unaryop) -> Optional[ir.UnaryOperator]:
    """Map an ``ast`` unary operator to its IR UnaryOperator.

    Only negation (``-x``) and ``not`` are mapped. ``+x`` and ``~x`` give None.
    """
    codes = ir.UnaryOperator
    return {
        ast.USub: codes.UNARY_OP_NEG,
        ast.Not: codes.UNARY_OP_NOT,
    }.get(type(op))
3106
+
3107
+
3108
def _cmp_op_to_ir(op: ast.cmpop) -> Optional[ir.BinaryOperator]:
    """Map an ``ast`` comparison operator to its IR BinaryOperator.

    Covers ``== != < <= > >= in not in``. The identity tests ``is`` and
    ``is not`` have no entry and give None.
    """
    codes = ir.BinaryOperator
    return {
        ast.Eq: codes.BINARY_OP_EQ,
        ast.NotEq: codes.BINARY_OP_NE,
        ast.Lt: codes.BINARY_OP_LT,
        ast.LtE: codes.BINARY_OP_LE,
        ast.Gt: codes.BINARY_OP_GT,
        ast.GtE: codes.BINARY_OP_GE,
        ast.In: codes.BINARY_OP_IN,
        ast.NotIn: codes.BINARY_OP_NOT_IN,
    }.get(type(op))
3121
+
3122
+
3123
def _bool_op_to_ir(op: ast.boolop) -> Optional[ir.BinaryOperator]:
    """Map ``and`` / ``or`` to the corresponding IR BinaryOperator."""
    codes = ir.BinaryOperator
    return {
        ast.And: codes.BINARY_OP_AND,
        ast.Or: codes.BINARY_OP_OR,
    }.get(type(op))
3130
+
3131
+
3132
def _get_func_name(func: ast.expr) -> Optional[str]:
    """Return the dotted name of a call's ``func`` node.

    ``foo`` gives ``"foo"`` and ``self.client.fetch`` gives
    ``"self.client.fetch"``. When an attribute chain is not rooted in a plain
    name (e.g. ``f().g``), only the attribute segments are joined (``"g"``).
    Any other callee node gives None.
    """
    if isinstance(func, ast.Name):
        return func.id
    if isinstance(func, ast.Attribute):
        attrs: List[str] = []
        node: ast.expr = func
        while isinstance(node, ast.Attribute):
            attrs.insert(0, node.attr)
            node = node.value
        root = [node.id] if isinstance(node, ast.Name) else []
        return ".".join(root + attrs)
    return None