rappel 0.5.7__py3-none-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
rappel/ir_builder.py ADDED
@@ -0,0 +1,3191 @@
1
+ """
2
+ IR Builder - Converts Python workflow AST to Rappel IR (ast.proto).
3
+
4
+ This module parses Python workflow classes and produces the IR representation
5
+ that can be sent to the Rust runtime for execution.
6
+
7
+ The IR builder performs deep transformations to convert Python patterns into
8
+ valid Rappel IR structures. Control flow bodies (try, for, if branches) are
9
+ represented as full blocks, so they can contain arbitrary statements and
10
+ `return` behaves consistently with Python (returns from the entry function end
11
+ the workflow).
12
+
13
+ Validation:
14
+ The IR builder proactively detects unsupported Python patterns and raises
15
+ UnsupportedPatternError with clear recommendations for how to rewrite the code.
16
+ """
17
+
18
+ from __future__ import annotations
19
+
20
+ import ast
21
+ import copy
22
+ import inspect
23
+ import textwrap
24
+ from dataclasses import dataclass
25
+ from typing import TYPE_CHECKING, Any, Dict, List, Optional, Set, Union
26
+
27
+ from proto import ast_pb2 as ir
28
+
29
+
30
class UnsupportedPatternError(Exception):
    """Raised when the IR builder encounters an unsupported Python pattern.

    The rendered exception text is the message, an optional parenthesised
    location (filename, line, column - whichever are known), and a
    recommendation for how to rewrite the code to use supported patterns.

    Args:
        message: Short description of the unsupported pattern.
        recommendation: Advice on how to rewrite the offending code.
        line: Line number of the offending node, if known.
        col: Column of the offending node, if known.
        filename: Source file containing the offending code, if known.
    """

    def __init__(
        self,
        message: str,
        recommendation: str,
        line: Optional[int] = None,
        col: Optional[int] = None,
        filename: Optional[str] = None,
    ):
        self.message = message
        self.recommendation = recommendation
        self.line = line
        self.col = col
        self.filename = filename

        location_parts: List[str] = []
        if filename:
            location_parts.append(filename)
        # Test against None (like `col` below) so the two position fields are
        # treated consistently.
        if line is not None:
            location_parts.append(f"line {line}")
        if col is not None:
            location_parts.append(f"col {col}")
        location = f" ({', '.join(location_parts)})" if location_parts else ""
        full_message = f"{message}{location}\n\nRecommendation: {recommendation}"
        super().__init__(full_message)

    def __reduce__(self):
        """Rebuild from the original constructor arguments.

        The default exception reduction replays ``self.args`` (only the
        formatted message), which does not satisfy this constructor's two
        required parameters and makes copy/pickle fail with TypeError.
        """
        return (
            type(self),
            (self.message, self.recommendation, self.line, self.col, self.filename),
        )
61
+
62
+
63
# Recommendations for common unsupported patterns.
#
# Maps a short pattern key to the advice text that is passed as the
# `recommendation` argument of UnsupportedPatternError (see
# IRBuilder._check_unsupported_statement for the statement-level keys).
# NOTE(review): the example snippets embedded in these strings appear with a
# single leading space in this view; confirm the intended indentation against
# the packaged file before editing them.
RECOMMENDATIONS = {
    "constructor_return": (
        "Returning a class constructor (like MyModel(...)) directly is not supported.\n"
        "The workflow IR cannot serialize arbitrary object instantiation.\n\n"
        "Use an @action to create the object:\n\n"
        " @action\n"
        " async def build_result(items: list, count: int) -> MyResult:\n"
        " return MyResult(items=items, count=count)\n\n"
        " # In workflow:\n"
        " return await build_result(items, count)"
    ),
    "constructor_assignment": (
        "Assigning a class constructor result (like x = MyClass(...)) is not supported.\n"
        "The workflow IR cannot serialize arbitrary object instantiation.\n\n"
        "Use an @action to create the object:\n\n"
        " @action\n"
        " async def create_config(value: int) -> Config:\n"
        " return Config(value=value)\n\n"
        " # In workflow:\n"
        " config = await create_config(value)"
    ),
    "non_action_call": (
        "Calling a function that is not decorated with @action is not supported.\n"
        "Only @action decorated functions can be awaited in workflow code.\n\n"
        "Add the @action decorator to your function:\n\n"
        " @action\n"
        " async def my_function(x: int) -> int:\n"
        " return x * 2"
    ),
    "sync_function_call": (
        "Calling a synchronous function directly in workflow code is not supported.\n"
        "All computation must happen inside @action decorated async functions.\n\n"
        "Wrap your logic in an @action:\n\n"
        " @action\n"
        " async def compute(x: int) -> int:\n"
        " return some_sync_function(x)"
    ),
    "method_call_non_self": (
        "Calling methods on objects other than 'self' is not supported in workflow code.\n"
        "Use an @action to perform method calls:\n\n"
        " @action\n"
        " async def call_method(obj: MyClass) -> Result:\n"
        " return obj.some_method()"
    ),
    "builtin_call": (
        "Calling built-in functions like len(), str(), int() directly is not supported.\n"
        "Use an @action to perform these operations:\n\n"
        " @action\n"
        " async def get_length(items: list) -> int:\n"
        " return len(items)"
    ),
    "fstring": (
        "F-strings are not supported in workflow code because they require "
        "runtime string interpolation.\n"
        "Use an @action to perform string formatting:\n\n"
        " @action\n"
        " async def format_message(value: int) -> str:\n"
        " return f'Result: {value}'"
    ),
    "delete": (
        "The 'del' statement is not supported in workflow code.\n"
        "Use an @action to perform mutations:\n\n"
        " @action\n"
        " async def remove_key(data: dict, key: str) -> dict:\n"
        " del data[key]\n"
        " return data"
    ),
    "while_loop": (
        "While loops are not supported in workflow code because they can run "
        "indefinitely.\n"
        "Use a for loop with a fixed range, or restructure as recursive workflow calls."
    ),
    "with_statement": (
        "Context managers (with statements) are not supported in workflow code.\n"
        "Use an @action to handle resource management:\n\n"
        " @action\n"
        " async def read_file(path: str) -> str:\n"
        " with open(path) as f:\n"
        " return f.read()"
    ),
    "raise_statement": (
        "The 'raise' statement is not supported directly in workflow code.\n"
        "Use an @action that raises exceptions, or return error values."
    ),
    "assert_statement": (
        "Assert statements are not supported in workflow code.\n"
        "Use an @action for validation, or use if statements with explicit error handling."
    ),
    "lambda": (
        "Lambda expressions are not supported in workflow code.\n"
        "Use an @action to define the function logic."
    ),
    "list_comprehension": (
        "List comprehensions are only supported when assigned directly to a variable "
        "or inside asyncio.gather(*[...]).\n"
        "For other cases, use a for loop or an @action."
    ),
    "dict_comprehension": (
        "Dict comprehensions are only supported when assigned directly to a variable.\n"
        "For other cases, use a for loop or an @action:\n\n"
        " @action\n"
        " async def build_dict(items: list) -> dict:\n"
        " return {k: v for k, v in items}"
    ),
    "set_comprehension": (
        "Set comprehensions are not supported in workflow code.\nUse an @action to build sets."
    ),
    "generator": (
        "Generator expressions are not supported in workflow code.\n"
        "Use a list or an @action instead."
    ),
    "walrus": (
        "The walrus operator (:=) is not supported in workflow code.\n"
        "Use separate assignment statements instead."
    ),
    "match": (
        "Match statements are not supported in workflow code.\nUse if/elif/else chains instead."
    ),
    "gather_variable_spread": (
        "Spreading a variable in asyncio.gather() is not supported because it requires "
        "data flow analysis to determine the contents.\n"
        "Use a list comprehension directly in gather:\n\n"
        " # Instead of:\n"
        " tasks = []\n"
        " for i in range(count):\n"
        " tasks.append(process(value=i))\n"
        " results = await asyncio.gather(*tasks)\n\n"
        " # Use:\n"
        " results = await asyncio.gather(*[process(value=i) for i in range(count)])"
    ),
    "for_loop_append_pattern": (
        "Building a task list in a for loop then spreading in asyncio.gather() is not "
        "supported.\n"
        "Use a list comprehension directly in gather:\n\n"
        " # Instead of:\n"
        " tasks = []\n"
        " for i in range(count):\n"
        " tasks.append(process(value=i))\n"
        " results = await asyncio.gather(*tasks)\n\n"
        " # Use:\n"
        " results = await asyncio.gather(*[process(value=i) for i in range(count)])"
    ),
    "global_statement": (
        "Global statements are not supported in workflow code.\n"
        "Use workflow state or pass values explicitly."
    ),
    "nonlocal_statement": (
        "Nonlocal statements are not supported in workflow code.\n"
        "Use explicit parameter passing instead."
    ),
    "import_statement": (
        "Import statements inside workflow run() are not supported.\n"
        "Place imports at the module level."
    ),
    "class_def": (
        "Class definitions inside workflow run() are not supported.\n"
        "Define classes at the module level."
    ),
    "function_def": (
        "Nested function definitions inside workflow run() are not supported.\n"
        "Define functions at the module level or use @action."
    ),
    "yield_statement": (
        "Yield statements are not supported in workflow code.\n"
        "Workflows must return a complete result, not generate values incrementally."
    ),
    "unsupported_expression": (
        "This expression type is not supported in workflow code.\n"
        "Move the logic into an @action or rewrite using supported expressions."
    ),
    "unsupported_literal": (
        "This literal type is not supported in workflow code.\n"
        "Convert the value to a supported literal type inside an @action."
    ),
}
239
+
240
+
241
+ if TYPE_CHECKING:
242
+ from .workflow import Workflow
243
+
244
+
245
@dataclass
class ActionDefinition:
    """Record describing a single action function.

    Instances are created by ``_discover_action_names`` for every module
    attribute that carries a ``__rappel_action_name__`` marker.
    """

    # Registered action name (taken from ``__rappel_action_name__``).
    action_name: str
    # Module the action belongs to; may be None.
    module_name: Optional[str]
    # Call signature of the action callable.
    signature: inspect.Signature
252
+
253
+
254
@dataclass
class ModuleContext:
    """Per-module lookup tables handed to IRBuilder, cached by module name."""

    # Local attribute name -> action definition.
    action_defs: Dict[str, ActionDefinition]
    # Local name -> record of the `from ... import ...` that bound it.
    imported_names: Dict[str, "ImportedName"]
    # Names of coroutine functions defined in the module itself.
    module_functions: Set[str]
    # Local class name -> model (Pydantic / dataclass) definition.
    model_defs: Dict[str, "ModelDefinition"]
262
+
263
+
264
@dataclass
class TransformContext:
    """Mutable state shared across IR transformations."""

    # Number of implicit function names handed out so far.
    implicit_fn_counter: int = 0
    # Implicit functions produced while transforming; None is replaced by a
    # fresh list in __post_init__ so instances never share one list.
    implicit_functions: List[ir.FunctionDef] = None  # type: ignore

    def __post_init__(self) -> None:
        if self.implicit_functions is not None:
            return
        self.implicit_functions = []

    def next_implicit_fn_name(self, prefix: str = "implicit") -> str:
        """Bump the counter and return a name of the form ``__<prefix>_<n>__``."""
        sequence_number = self.implicit_fn_counter + 1
        self.implicit_fn_counter = sequence_number
        return f"__{prefix}_{sequence_number}__"
281
+
282
+
283
def build_workflow_ir(workflow_cls: type["Workflow"]) -> ir.Program:
    """Build an IR Program from a workflow class.

    The workflow's run implementation is converted first and stored under the
    IR name ``main``. Helper methods reached through ``self.method(...)``
    calls are then converted transitively. In the resulting program the
    implicit functions recorded on the shared TransformContext come first,
    followed by ``main`` and the helpers.

    Helper methods are visited in sorted name order so that the emitted
    function order and the numbering of implicit functions do not depend on
    set iteration order (which varies with string hash randomization).

    Args:
        workflow_cls: The workflow class to convert.

    Returns:
        An IR Program proto message.

    Raises:
        ValueError: If the run implementation, or the module of the run
            implementation or of a helper, cannot be located.
        UnsupportedPatternError: If unsupported code is found; the error is
            re-raised with file-relative location information.
    """
    original_run = getattr(workflow_cls, "__workflow_run_impl__", None)
    if original_run is None:
        original_run = workflow_cls.__dict__.get("run")
    if original_run is None:
        raise ValueError(f"workflow {workflow_cls!r} missing run implementation")

    module = inspect.getmodule(original_run)
    if module is None:
        raise ValueError(f"unable to locate module for workflow {workflow_cls!r}")

    module_contexts: Dict[str, ModuleContext] = {}

    def get_module_context(target_module: Any) -> ModuleContext:
        """Return the (lazily built, cached) discovery context for a module."""
        module_name = target_module.__name__
        if module_name not in module_contexts:
            module_contexts[module_name] = ModuleContext(
                action_defs=_discover_action_names(target_module),
                imported_names=_discover_module_imports(target_module),
                module_functions=_discover_module_functions(target_module),
                model_defs=_discover_model_definitions(target_module),
            )
        return module_contexts[module_name]

    # Build the IR with transformation context
    ctx = TransformContext()
    program = ir.Program()
    function_defs: Dict[str, ir.FunctionDef] = {}

    def parse_function(fn: Any) -> tuple[ast.AST, Optional[str], int]:
        """Parse a function's dedented source; return (tree, filename, first line)."""
        source_lines, start_line = inspect.getsourcelines(fn)
        function_source = textwrap.dedent("".join(source_lines))
        filename = inspect.getsourcefile(fn)
        if filename is None:
            filename = inspect.getfile(fn)
        return ast.parse(function_source, filename=filename or "<unknown>"), filename, start_line

    def _with_source_location(
        err: UnsupportedPatternError,
        filename: Optional[str],
        start_line: int,
    ) -> UnsupportedPatternError:
        """Return a copy of ``err`` positioned relative to the source file."""
        line = err.line
        col = err.col
        if line is not None:
            # err.line is relative to the parsed snippet (1-based); shift it
            # by the function's first line in the file.
            line = start_line + line - 1
        if col is not None:
            # AST column offsets are 0-based; report them 1-based.
            col = col + 1
        return UnsupportedPatternError(
            err.message,
            err.recommendation,
            line=line,
            col=col,
            filename=filename,
        )

    def add_function_def(
        fn: Any,
        fn_tree: ast.AST,
        filename: Optional[str],
        start_line: int,
        override_name: Optional[str] = None,
    ) -> None:
        """Convert one parsed function and register it in ``function_defs``."""
        fn_module = inspect.getmodule(fn)
        if fn_module is None:
            raise ValueError(f"unable to locate module for function {fn!r}")

        ctx_data = get_module_context(fn_module)
        builder = IRBuilder(
            ctx_data.action_defs,
            ctx,
            ctx_data.imported_names,
            ctx_data.module_functions,
            ctx_data.model_defs,
        )
        try:
            builder.visit(fn_tree)
        except UnsupportedPatternError as err:
            raise _with_source_location(err, filename, start_line) from err
        if builder.function_def:
            if override_name:
                builder.function_def.name = override_name
            function_defs[builder.function_def.name] = builder.function_def

    # Build the entry function from run() and map it to main in the IR.
    run_tree, run_filename, run_start_line = parse_function(original_run)
    add_function_def(original_run, run_tree, run_filename, run_start_line, "main")

    # Include helper methods reachable via self.method() calls.
    # Reverse-sorted so that pop() (which takes from the end) yields names in
    # ascending order, keeping the traversal reproducible between runs.
    pending = sorted(_collect_self_method_calls(run_tree), reverse=True)
    visited: Set[str] = set()
    skip_methods = {"run_action"}

    while pending:
        method_name = pending.pop()
        if method_name in visited or method_name == "run" or method_name in skip_methods:
            continue
        visited.add(method_name)

        method = _find_workflow_method(workflow_cls, method_name)
        if method is None:
            continue

        method_tree, method_filename, method_start_line = parse_function(method)
        add_function_def(method, method_tree, method_filename, method_start_line)

        pending.extend(sorted(_collect_self_method_calls(method_tree), reverse=True))

    # Add implicit functions first (they may be called by the main function)
    for implicit_fn in ctx.implicit_functions:
        program.functions.append(implicit_fn)

    # Add all function definitions (run + reachable helper methods)
    for fn_def in function_defs.values():
        program.functions.append(fn_def)

    return program
408
+
409
+
410
+ def _collect_self_method_calls(tree: ast.AST) -> Set[str]:
411
+ """Collect self.method(...) call names from a parsed function AST."""
412
+ calls: Set[str] = set()
413
+ for node in ast.walk(tree):
414
+ if not isinstance(node, ast.Call):
415
+ continue
416
+ func = node.func
417
+ if isinstance(func, ast.Attribute) and isinstance(func.value, ast.Name):
418
+ if func.value.id == "self":
419
+ calls.add(func.attr)
420
+ return calls
421
+
422
+
423
+ def _find_workflow_method(workflow_cls: type["Workflow"], name: str) -> Optional[Any]:
424
+ """Find a workflow method by name across the class MRO."""
425
+ for base in workflow_cls.__mro__:
426
+ if name not in base.__dict__:
427
+ continue
428
+ value = base.__dict__[name]
429
+ if isinstance(value, staticmethod) or isinstance(value, classmethod):
430
+ return value.__func__
431
+ if inspect.isfunction(value):
432
+ return value
433
+ return None
434
+
435
+
436
+ def _discover_action_names(module: Any) -> Dict[str, ActionDefinition]:
437
+ """Discover all @action decorated functions in a module."""
438
+ names: Dict[str, ActionDefinition] = {}
439
+ for attr_name in dir(module):
440
+ attr = getattr(module, attr_name)
441
+ action_name = getattr(attr, "__rappel_action_name__", None)
442
+ action_module = getattr(attr, "__rappel_action_module__", None)
443
+ if callable(attr) and action_name:
444
+ signature = inspect.signature(attr)
445
+ names[attr_name] = ActionDefinition(
446
+ action_name=action_name,
447
+ module_name=action_module or module.__name__,
448
+ signature=signature,
449
+ )
450
+ return names
451
+
452
+
453
+ def _discover_module_functions(module: Any) -> Set[str]:
454
+ """Discover all async function names defined in a module.
455
+
456
+ This is used to detect when users await functions in the same module
457
+ that are NOT decorated with @action.
458
+ """
459
+ function_names: Set[str] = set()
460
+ for attr_name in dir(module):
461
+ try:
462
+ attr = getattr(module, attr_name)
463
+ except AttributeError:
464
+ continue
465
+
466
+ # Only include functions defined in THIS module (not imported)
467
+ if not callable(attr):
468
+ continue
469
+ if not inspect.iscoroutinefunction(attr):
470
+ continue
471
+
472
+ # Check if the function is defined in this module
473
+ func_module = getattr(attr, "__module__", None)
474
+ if func_module == module.__name__:
475
+ function_names.add(attr_name)
476
+
477
+ return function_names
478
+
479
+
480
@dataclass
class ImportedName:
    """One name bound by a ``from <module> import <name> [as <alias>]``."""

    # Name as it is used in the importing code (e.g., "sleep", or the alias).
    local_name: str
    # Module the name was imported from (e.g., "asyncio").
    module: str
    # Name as spelled in the source module (e.g., "sleep").
    original_name: str
487
+
488
+
489
@dataclass
class ModelFieldDefinition:
    """Describes one field of a Pydantic model or dataclass."""

    # Field name.
    name: str
    # Whether the field declares a plain (non-factory) default.
    has_default: bool
    # The default itself; only meaningful when has_default is True.
    default_value: Any = None
496
+
497
+
498
@dataclass
class ModelDefinition:
    """Describes a Pydantic model or dataclass usable in workflow code.

    Such data classes may be instantiated in workflow code; the builder turns
    those instantiations into dictionary expressions in the IR.
    """

    # Name under which the class is visible in the scanned module.
    class_name: str
    # Field name -> field description.
    fields: Dict[str, ModelFieldDefinition]
    # True for Pydantic models, False for dataclasses.
    is_pydantic: bool
509
+
510
+
511
+ def _is_simple_pydantic_model(cls: type) -> bool:
512
+ """Check if a class is a simple Pydantic model without custom validators.
513
+
514
+ A simple Pydantic model:
515
+ - Inherits from pydantic.BaseModel
516
+ - Has no field_validator or model_validator decorators
517
+ - Has no custom __init__ method
518
+
519
+ Returns False if pydantic is not installed or cls is not a Pydantic model.
520
+ """
521
+ try:
522
+ from pydantic import BaseModel
523
+ except ImportError:
524
+ return False
525
+
526
+ if not isinstance(cls, type) or not issubclass(cls, BaseModel):
527
+ return False
528
+
529
+ # Check for validators - Pydantic v2 uses __pydantic_decorators__
530
+ decorators = getattr(cls, "__pydantic_decorators__", None)
531
+ if decorators is not None:
532
+ # Check for field validators
533
+ if hasattr(decorators, "field_validators") and decorators.field_validators:
534
+ return False
535
+ # Check for model validators
536
+ if hasattr(decorators, "model_validators") and decorators.model_validators:
537
+ return False
538
+
539
+ # Check for custom __init__ (not the one from BaseModel)
540
+ if "__init__" in cls.__dict__:
541
+ return False
542
+
543
+ return True
544
+
545
+
546
+ def _is_simple_dataclass(cls: type) -> bool:
547
+ """Check if a class is a simple dataclass without custom logic.
548
+
549
+ A simple dataclass:
550
+ - Is decorated with @dataclass
551
+ - Has no custom __init__ method (uses the generated one)
552
+ - Has no __post_init__ method
553
+
554
+ Returns False if cls is not a dataclass.
555
+ """
556
+ import dataclasses
557
+
558
+ if not dataclasses.is_dataclass(cls):
559
+ return False
560
+
561
+ # Check for __post_init__ which could have custom logic
562
+ if hasattr(cls, "__post_init__") and "__post_init__" in cls.__dict__:
563
+ return False
564
+
565
+ # Dataclasses generate __init__, so check if there's a custom one
566
+ # that overrides it (unlikely but possible)
567
+ # The dataclass decorator sets __init__ so we can't easily detect override
568
+ # We'll trust that dataclasses without __post_init__ are simple
569
+
570
+ return True
571
+
572
+
573
+ def _get_pydantic_model_fields(cls: type) -> Dict[str, ModelFieldDefinition]:
574
+ """Extract field definitions from a Pydantic model."""
575
+ try:
576
+ from pydantic import BaseModel
577
+ from pydantic.fields import FieldInfo
578
+ except ImportError:
579
+ return {}
580
+
581
+ if not issubclass(cls, BaseModel):
582
+ return {}
583
+
584
+ fields: Dict[str, ModelFieldDefinition] = {}
585
+
586
+ # Pydantic v2 uses model_fields
587
+ model_fields = getattr(cls, "model_fields", {})
588
+ for field_name, field_info in model_fields.items():
589
+ has_default = False
590
+ default_value = None
591
+
592
+ if isinstance(field_info, FieldInfo):
593
+ # Check if field has a default value
594
+ # PydanticUndefined means no default
595
+ from pydantic_core import PydanticUndefined
596
+
597
+ if field_info.default is not PydanticUndefined:
598
+ has_default = True
599
+ default_value = field_info.default
600
+ elif field_info.default_factory is not None:
601
+ # We can't serialize factory functions, so treat as no default
602
+ has_default = False
603
+
604
+ fields[field_name] = ModelFieldDefinition(
605
+ name=field_name,
606
+ has_default=has_default,
607
+ default_value=default_value,
608
+ )
609
+
610
+ return fields
611
+
612
+
613
+ def _get_dataclass_fields(cls: type) -> Dict[str, ModelFieldDefinition]:
614
+ """Extract field definitions from a dataclass."""
615
+ import dataclasses
616
+
617
+ if not dataclasses.is_dataclass(cls):
618
+ return {}
619
+
620
+ fields: Dict[str, ModelFieldDefinition] = {}
621
+ for field in dataclasses.fields(cls):
622
+ has_default = False
623
+ default_value = None
624
+
625
+ if field.default is not dataclasses.MISSING:
626
+ has_default = True
627
+ default_value = field.default
628
+ elif field.default_factory is not dataclasses.MISSING:
629
+ # We can't serialize factory functions, so treat as no default
630
+ has_default = False
631
+
632
+ fields[field.name] = ModelFieldDefinition(
633
+ name=field.name,
634
+ has_default=has_default,
635
+ default_value=default_value,
636
+ )
637
+
638
+ return fields
639
+
640
+
641
+ def _discover_model_definitions(module: Any) -> Dict[str, ModelDefinition]:
642
+ """Discover all Pydantic models and dataclasses that can be used in workflows.
643
+
644
+ Only discovers "simple" models without custom validators or __post_init__.
645
+ """
646
+ models: Dict[str, ModelDefinition] = {}
647
+
648
+ for attr_name in dir(module):
649
+ try:
650
+ attr = getattr(module, attr_name)
651
+ except AttributeError:
652
+ continue
653
+
654
+ if not isinstance(attr, type):
655
+ continue
656
+
657
+ # Check if this class is defined in this module or imported
658
+ # We want to include both, as models might be imported
659
+ if _is_simple_pydantic_model(attr):
660
+ fields = _get_pydantic_model_fields(attr)
661
+ models[attr_name] = ModelDefinition(
662
+ class_name=attr_name,
663
+ fields=fields,
664
+ is_pydantic=True,
665
+ )
666
+ elif _is_simple_dataclass(attr):
667
+ fields = _get_dataclass_fields(attr)
668
+ models[attr_name] = ModelDefinition(
669
+ class_name=attr_name,
670
+ fields=fields,
671
+ is_pydantic=False,
672
+ )
673
+
674
+ return models
675
+
676
+
677
+ def _discover_module_imports(module: Any) -> Dict[str, ImportedName]:
678
+ """Discover imports in a module by parsing its source.
679
+
680
+ Tracks imports like:
681
+ - from asyncio import sleep -> {"sleep": ImportedName("sleep", "asyncio", "sleep")}
682
+ - from asyncio import sleep as s -> {"s": ImportedName("s", "asyncio", "sleep")}
683
+ """
684
+ imported: Dict[str, ImportedName] = {}
685
+
686
+ try:
687
+ source = inspect.getsource(module)
688
+ tree = ast.parse(source)
689
+ except (OSError, TypeError):
690
+ # Can't get source (e.g., built-in module)
691
+ return imported
692
+
693
+ for node in ast.walk(tree):
694
+ if isinstance(node, ast.ImportFrom) and node.module:
695
+ for alias in node.names:
696
+ local_name = alias.asname if alias.asname else alias.name
697
+ imported[local_name] = ImportedName(
698
+ local_name=local_name,
699
+ module=node.module,
700
+ original_name=alias.name,
701
+ )
702
+
703
+ return imported
704
+
705
+
706
+ class IRBuilder(ast.NodeVisitor):
707
+ """Builds IR from Python AST with deep transformations."""
708
+
709
+ def __init__(
710
+ self,
711
+ action_defs: Dict[str, ActionDefinition],
712
+ ctx: TransformContext,
713
+ imported_names: Optional[Dict[str, ImportedName]] = None,
714
+ module_functions: Optional[Set[str]] = None,
715
+ model_defs: Optional[Dict[str, ModelDefinition]] = None,
716
+ ):
717
+ self._action_defs = action_defs
718
+ self._ctx = ctx
719
+ self._imported_names = imported_names or {}
720
+ self._module_functions = module_functions or set()
721
+ self._model_defs = model_defs or {}
722
+ self.function_def: Optional[ir.FunctionDef] = None
723
+ self._statements: List[ir.Statement] = []
724
+
725
def visit_FunctionDef(self, node: ast.FunctionDef) -> Any:
    """Visit a function definition (the workflow's run method)."""
    self._build_function_def_from_node(node)

def visit_AsyncFunctionDef(self, node: ast.AsyncFunctionDef) -> Any:
    """Visit an async function definition (the workflow's run method)."""
    # Async functions are handled the same way as sync ones for IR building.
    self._build_function_def_from_node(node)

def _build_function_def_from_node(
    self, node: Union[ast.FunctionDef, ast.AsyncFunctionDef]
) -> None:
    """Populate ``self.function_def`` from a (sync or async) function node.

    Shared implementation of both visit methods: collects the inputs,
    creates the IR function definition, converts every body statement and
    stores the resulting block as the function body.
    """
    inputs = self._collect_function_inputs(node)

    # Create the function definition
    self.function_def = ir.FunctionDef(
        name=node.name,
        io=ir.IoDecl(inputs=inputs, outputs=[]),
        span=_make_span(node),
    )

    # Visit the body - _visit_statement returns a list per statement
    self._statements = []
    for stmt in node.body:
        self._statements.extend(self._visit_statement(stmt))

    # Set the body
    self.function_def.body.CopyFrom(ir.Block(statements=self._statements))
762
+
763
+ def _visit_statement(self, node: ast.stmt) -> List[ir.Statement]:
764
+ """Convert a Python statement to IR Statement(s).
765
+
766
+ Returns a list because some transformations (like try block hoisting)
767
+ may expand a single Python statement into multiple IR statements.
768
+
769
+ Raises UnsupportedPatternError for unsupported statement types.
770
+ """
771
+ if isinstance(node, ast.Assign):
772
+ dict_expanded = self._expand_dict_comprehension_assignment(node)
773
+ if dict_expanded is not None:
774
+ return dict_expanded
775
+ expanded = self._expand_list_comprehension_assignment(node)
776
+ if expanded is not None:
777
+ return expanded
778
+ result = self._visit_assign(node)
779
+ return [result] if result else []
780
+ elif isinstance(node, ast.Expr):
781
+ result = self._visit_expr_stmt(node)
782
+ return [result] if result else []
783
+ elif isinstance(node, ast.For):
784
+ return self._visit_for(node)
785
+ elif isinstance(node, ast.If):
786
+ return self._visit_if(node)
787
+ elif isinstance(node, ast.Try):
788
+ return self._visit_try(node)
789
+ elif isinstance(node, ast.Return):
790
+ return self._visit_return(node)
791
+ elif isinstance(node, ast.AugAssign):
792
+ return self._visit_aug_assign(node)
793
+ elif isinstance(node, ast.Pass):
794
+ # Pass statements are fine, they just don't produce IR
795
+ return []
796
+
797
+ # Check for unsupported statement types
798
+ self._check_unsupported_statement(node)
799
+
800
+ return []
801
+
802
+ def _check_unsupported_statement(self, node: ast.stmt) -> None:
803
+ """Check for unsupported statement types and raise descriptive errors."""
804
+ line = getattr(node, "lineno", None)
805
+ col = getattr(node, "col_offset", None)
806
+
807
+ if isinstance(node, ast.While):
808
+ raise UnsupportedPatternError(
809
+ "While loops are not supported",
810
+ RECOMMENDATIONS["while_loop"],
811
+ line=line,
812
+ col=col,
813
+ )
814
+ elif isinstance(node, (ast.With, ast.AsyncWith)):
815
+ raise UnsupportedPatternError(
816
+ "Context managers (with statements) are not supported",
817
+ RECOMMENDATIONS["with_statement"],
818
+ line=line,
819
+ col=col,
820
+ )
821
+ elif isinstance(node, ast.Raise):
822
+ raise UnsupportedPatternError(
823
+ "The 'raise' statement is not supported",
824
+ RECOMMENDATIONS["raise_statement"],
825
+ line=line,
826
+ col=col,
827
+ )
828
+ elif isinstance(node, ast.Assert):
829
+ raise UnsupportedPatternError(
830
+ "Assert statements are not supported",
831
+ RECOMMENDATIONS["assert_statement"],
832
+ line=line,
833
+ col=col,
834
+ )
835
+ elif isinstance(node, ast.Delete):
836
+ raise UnsupportedPatternError(
837
+ "The 'del' statement is not supported",
838
+ RECOMMENDATIONS["delete"],
839
+ line=line,
840
+ col=col,
841
+ )
842
+ elif isinstance(node, ast.Global):
843
+ raise UnsupportedPatternError(
844
+ "Global statements are not supported",
845
+ RECOMMENDATIONS["global_statement"],
846
+ line=line,
847
+ col=col,
848
+ )
849
+ elif isinstance(node, ast.Nonlocal):
850
+ raise UnsupportedPatternError(
851
+ "Nonlocal statements are not supported",
852
+ RECOMMENDATIONS["nonlocal_statement"],
853
+ line=line,
854
+ col=col,
855
+ )
856
+ elif isinstance(node, (ast.Import, ast.ImportFrom)):
857
+ raise UnsupportedPatternError(
858
+ "Import statements inside workflow run() are not supported",
859
+ RECOMMENDATIONS["import_statement"],
860
+ line=line,
861
+ col=col,
862
+ )
863
+ elif isinstance(node, ast.ClassDef):
864
+ raise UnsupportedPatternError(
865
+ "Class definitions inside workflow run() are not supported",
866
+ RECOMMENDATIONS["class_def"],
867
+ line=line,
868
+ col=col,
869
+ )
870
+ elif isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef)):
871
+ raise UnsupportedPatternError(
872
+ "Nested function definitions are not supported",
873
+ RECOMMENDATIONS["function_def"],
874
+ line=line,
875
+ col=col,
876
+ )
877
+ elif hasattr(ast, "Match") and isinstance(node, ast.Match):
878
+ raise UnsupportedPatternError(
879
+ "Match statements are not supported",
880
+ RECOMMENDATIONS["match"],
881
+ line=line,
882
+ col=col,
883
+ )
884
+
885
def _expand_list_comprehension_assignment(
    self, node: ast.Assign
) -> Optional[List[ir.Statement]]:
    """Expand a list comprehension assignment into loop-based statements.

    Example:
        active_users = [user for user in users if user.active]

    Becomes:
        active_users = []
        for user in users:
            if user.active:
                active_users = active_users + [user]

    A conditional element (``a if c else b``) becomes an if/else inside the
    loop, and multiple ``if`` filters are combined with ``and``.

    If the comprehension itself mentions the assignment target, e.g.
    ``items = [f(i) for i in items]``, resetting the target before the loop
    would make the loop read the emptied list. In that case the results are
    accumulated in a fresh implicit variable and copied to the target after
    the loop:
        __tmp__ = []
        for i in items:
            __tmp__ = __tmp__ + [f(i)]
        items = __tmp__

    Returns:
        None when the assigned value is not a list comprehension, otherwise
        the IR statements produced for the expansion.

    Raises:
        UnsupportedPatternError: for non-simple targets, multiple generators,
            or async comprehensions.
    """
    if not isinstance(node.value, ast.ListComp):
        return None

    if len(node.targets) != 1 or not isinstance(node.targets[0], ast.Name):
        line = getattr(node, "lineno", None)
        col = getattr(node, "col_offset", None)
        raise UnsupportedPatternError(
            "List comprehension assignments must target a single variable",
            "Assign the comprehension to a simple variable like `results = [x for x in items]`",
            line=line,
            col=col,
        )

    listcomp = node.value
    if len(listcomp.generators) != 1:
        line = getattr(listcomp, "lineno", None)
        col = getattr(listcomp, "col_offset", None)
        raise UnsupportedPatternError(
            "List comprehensions with multiple generators are not supported",
            "Use nested for loops instead of combining multiple generators in one comprehension",
            line=line,
            col=col,
        )

    gen = listcomp.generators[0]
    if gen.is_async:
        line = getattr(listcomp, "lineno", None)
        col = getattr(listcomp, "col_offset", None)
        raise UnsupportedPatternError(
            "Async list comprehensions are not supported",
            "Rewrite using an explicit async for loop",
            line=line,
            col=col,
        )

    target_name = node.targets[0].id

    # Python evaluates the whole comprehension before rebinding the target.
    # If the target name appears anywhere inside the comprehension (iterable,
    # filters, element or loop variable), accumulate into a separate implicit
    # variable so the loop still sees the target's previous value.
    references_target = any(
        isinstance(sub, ast.Name) and sub.id == target_name for sub in ast.walk(listcomp)
    )
    if references_target:
        accumulator_name = self._ctx.next_implicit_fn_name(prefix="listcomp")
    else:
        accumulator_name = target_name

    # Initialize the accumulator list: active_users = []
    init_assign_ast = ast.Assign(
        targets=[ast.Name(id=accumulator_name, ctx=ast.Store())],
        value=ast.List(elts=[], ctx=ast.Load()),
        type_comment=None,
    )
    ast.copy_location(init_assign_ast, node)
    ast.fix_missing_locations(init_assign_ast)

    def _make_append_assignment(value_expr: ast.expr) -> ast.Assign:
        # acc = acc + [value]; concatenation keeps the mutation explicit.
        append_assign = ast.Assign(
            targets=[ast.Name(id=accumulator_name, ctx=ast.Store())],
            value=ast.BinOp(
                left=ast.Name(id=accumulator_name, ctx=ast.Load()),
                op=ast.Add(),
                right=ast.List(elts=[copy.deepcopy(value_expr)], ctx=ast.Load()),
            ),
            type_comment=None,
        )
        ast.copy_location(append_assign, node.value)
        ast.fix_missing_locations(append_assign)
        return append_assign

    append_statements: List[ast.stmt] = []
    if isinstance(listcomp.elt, ast.IfExp):
        # [a if c else b for ...] -> if c: append a / else: append b
        then_assign = _make_append_assignment(listcomp.elt.body)
        else_assign = _make_append_assignment(listcomp.elt.orelse)
        branch_if = ast.If(
            test=copy.deepcopy(listcomp.elt.test),
            body=[then_assign],
            orelse=[else_assign],
        )
        ast.copy_location(branch_if, listcomp.elt)
        ast.fix_missing_locations(branch_if)
        append_statements.append(branch_if)
    else:
        append_statements.append(_make_append_assignment(listcomp.elt))

    loop_body: List[ast.stmt] = append_statements
    if gen.ifs:
        condition: ast.expr
        if len(gen.ifs) == 1:
            condition = copy.deepcopy(gen.ifs[0])
        else:
            # Several `if` clauses on one generator are a conjunction.
            condition = ast.BoolOp(op=ast.And(), values=[copy.deepcopy(iff) for iff in gen.ifs])
        ast.copy_location(condition, gen.ifs[0])
        if_stmt = ast.If(test=condition, body=append_statements, orelse=[])
        ast.copy_location(if_stmt, gen.ifs[0])
        ast.fix_missing_locations(if_stmt)
        loop_body = [if_stmt]

    loop_ast = ast.For(
        target=copy.deepcopy(gen.target),
        iter=copy.deepcopy(gen.iter),
        body=loop_body,
        orelse=[],
        type_comment=None,
    )
    ast.copy_location(loop_ast, node)
    ast.fix_missing_locations(loop_ast)

    statements: List[ir.Statement] = []
    init_stmt = self._visit_assign(init_assign_ast)
    if init_stmt:
        statements.append(init_stmt)
    statements.extend(self._visit_for(loop_ast))

    if accumulator_name != target_name:
        # Publish the accumulated list under the user's variable name.
        final_assign_ast = ast.Assign(
            targets=[ast.Name(id=target_name, ctx=ast.Store())],
            value=ast.Name(id=accumulator_name, ctx=ast.Load()),
            type_comment=None,
        )
        ast.copy_location(final_assign_ast, node)
        ast.fix_missing_locations(final_assign_ast)
        final_stmt = self._visit_assign(final_assign_ast)
        if final_stmt:
            statements.append(final_stmt)

    return statements
1004
+
1005
def _expand_dict_comprehension_assignment(
    self, node: ast.Assign
) -> Optional[List[ir.Statement]]:
    """Expand a dict comprehension assignment into loop-based statements.

    Example:
        result = {k: v for k, v in pairs if v}

    Becomes:
        result = {}
        for k, v in pairs:
            if v:
                result[k] = v

    If the comprehension itself mentions the assignment target, e.g.
    ``d = {k: v for k, v in d.items()}``, resetting the target before the
    loop would make the loop read the emptied dict. In that case the entries
    are accumulated in a fresh implicit variable which is copied to the
    target after the loop.

    Returns:
        None when the assigned value is not a dict comprehension, otherwise
        the IR statements produced for the expansion.

    Raises:
        UnsupportedPatternError: for non-simple targets, multiple generators,
            or async comprehensions.
    """
    if not isinstance(node.value, ast.DictComp):
        return None

    if len(node.targets) != 1 or not isinstance(node.targets[0], ast.Name):
        line = getattr(node, "lineno", None)
        col = getattr(node, "col_offset", None)
        raise UnsupportedPatternError(
            "Dict comprehension assignments must target a single variable",
            "Assign the comprehension to a simple variable like `result = {k: v for k, v in pairs}`",
            line=line,
            col=col,
        )

    dictcomp = node.value
    if len(dictcomp.generators) != 1:
        line = getattr(dictcomp, "lineno", None)
        col = getattr(dictcomp, "col_offset", None)
        raise UnsupportedPatternError(
            "Dict comprehensions with multiple generators are not supported",
            "Use nested for loops instead of combining multiple generators in one comprehension",
            line=line,
            col=col,
        )

    gen = dictcomp.generators[0]
    if gen.is_async:
        line = getattr(dictcomp, "lineno", None)
        col = getattr(dictcomp, "col_offset", None)
        raise UnsupportedPatternError(
            "Async dict comprehensions are not supported",
            "Rewrite using an explicit async for loop",
            line=line,
            col=col,
        )

    target_name = node.targets[0].id

    # Python evaluates the whole comprehension before rebinding the target.
    # If the target name appears anywhere inside the comprehension, build the
    # dict under a separate implicit name so the loop still sees the target's
    # previous value.
    references_target = any(
        isinstance(sub, ast.Name) and sub.id == target_name for sub in ast.walk(dictcomp)
    )
    if references_target:
        accumulator_name = self._ctx.next_implicit_fn_name(prefix="dictcomp")
    else:
        accumulator_name = target_name

    # Initialize accumulator: result = {}
    init_assign_ast = ast.Assign(
        targets=[ast.Name(id=accumulator_name, ctx=ast.Store())],
        value=ast.Dict(keys=[], values=[]),
        type_comment=None,
    )
    ast.copy_location(init_assign_ast, node)
    ast.fix_missing_locations(init_assign_ast)

    # result[key] = value
    subscript_target = ast.Subscript(
        value=ast.Name(id=accumulator_name, ctx=ast.Load()),
        slice=copy.deepcopy(dictcomp.key),
        ctx=ast.Store(),
    )
    append_assign_ast = ast.Assign(
        targets=[subscript_target],
        value=copy.deepcopy(dictcomp.value),
        type_comment=None,
    )
    ast.copy_location(append_assign_ast, node.value)
    ast.fix_missing_locations(append_assign_ast)

    loop_body: List[ast.stmt] = []
    if gen.ifs:
        condition: ast.expr
        if len(gen.ifs) == 1:
            condition = copy.deepcopy(gen.ifs[0])
        else:
            # Several `if` clauses on one generator are a conjunction.
            condition = ast.BoolOp(op=ast.And(), values=[copy.deepcopy(iff) for iff in gen.ifs])
        ast.copy_location(condition, gen.ifs[0])
        if_stmt = ast.If(test=condition, body=[append_assign_ast], orelse=[])
        ast.copy_location(if_stmt, gen.ifs[0])
        ast.fix_missing_locations(if_stmt)
        loop_body.append(if_stmt)
    else:
        loop_body.append(append_assign_ast)

    loop_ast = ast.For(
        target=copy.deepcopy(gen.target),
        iter=copy.deepcopy(gen.iter),
        body=loop_body,
        orelse=[],
        type_comment=None,
    )
    ast.copy_location(loop_ast, node)
    ast.fix_missing_locations(loop_ast)

    statements: List[ir.Statement] = []
    init_stmt = self._visit_assign(init_assign_ast)
    if init_stmt:
        statements.append(init_stmt)
    statements.extend(self._visit_for(loop_ast))

    if accumulator_name != target_name:
        # Publish the accumulated dict under the user's variable name.
        final_assign_ast = ast.Assign(
            targets=[ast.Name(id=target_name, ctx=ast.Store())],
            value=ast.Name(id=accumulator_name, ctx=ast.Load()),
            type_comment=None,
        )
        ast.copy_location(final_assign_ast, node)
        ast.fix_missing_locations(final_assign_ast)
        final_stmt = self._visit_assign(final_assign_ast)
        if final_stmt:
            statements.append(final_stmt)

    return statements
1101
+
1102
def _visit_assign(self, node: ast.Assign) -> Optional[ir.Statement]:
    """Translate a Python assignment into an IR Assignment statement.

    Every supported right-hand side is wrapped in the same Assignment
    message so unpacking works uniformly. The right-hand side is tried in
    this order:
    1. model constructor call -> dict expression
    2. ``await asyncio.gather(...)`` -> parallel or spread expression
    3. action call
    4. any other expression convertible by ``_expr_to_ir``

    Returns:
        The IR statement, or None when the value cannot be converted.

    Raises:
        UnsupportedPatternError: propagated from
            ``_check_constructor_in_assignment`` and the other helpers.
    """
    stmt = ir.Statement(span=_make_span(node))
    targets = self._get_assign_targets(node.targets)
    rhs = node.value

    def finish(value: ir.Expr) -> ir.Statement:
        # Attach the converted value to the statement built above.
        stmt.assignment.CopyFrom(ir.Assignment(targets=targets, value=value))
        return stmt

    # Model constructors are allowed and lowered to dict expressions.
    model_name = self._is_model_constructor(rhs)
    if model_name and isinstance(rhs, ast.Call):
        return finish(self._convert_model_constructor_to_dict(rhs, model_name))

    # Runs after the model check because models are the permitted exception.
    self._check_constructor_in_assignment(rhs)

    # await asyncio.gather(...) becomes a parallel or spread expression.
    if isinstance(rhs, ast.Await) and isinstance(rhs.value, ast.Call):
        awaited_call = rhs.value
        if self._is_asyncio_gather_call(awaited_call):
            gathered = self._convert_asyncio_gather(awaited_call)
            if gathered is not None:
                if isinstance(gathered, ir.ParallelExpr):
                    return finish(ir.Expr(parallel_expr=gathered, span=_make_span(node)))
                return finish(ir.Expr(spread_expr=gathered, span=_make_span(node)))

    action_call = self._extract_action_call(rhs)
    if action_call:
        return finish(ir.Expr(action_call=action_call, span=_make_span(node)))

    # Plain variables, literals and expressions.
    converted = _expr_to_ir(rhs)
    if converted:
        return finish(converted)

    return None
1162
+
1163
def _visit_expr_stmt(self, node: ast.Expr) -> Optional[ir.Statement]:
    """Translate a bare expression statement (no assignment target) into IR.

    Handled forms, in order:
    1. ``await asyncio.gather(...)`` -> ParallelBlock statement, or an
       Assignment with no targets when the gather is a spread
    2. action call -> action_call statement
    3. ``name.append(x)`` -> ``name = name + [x]``
    4. anything else convertible by ``_expr_to_ir`` -> ExprStmt

    Returns:
        The IR statement, or None when the expression cannot be converted.
    """
    stmt = ir.Statement(span=_make_span(node))
    inner = node.value

    if isinstance(inner, ast.Await) and isinstance(inner.value, ast.Call):
        awaited_call = inner.value
        if self._is_asyncio_gather_call(awaited_call):
            gathered = self._convert_asyncio_gather(awaited_call)
            if gathered is not None:
                if isinstance(gathered, ir.ParallelExpr):
                    # Result is discarded, so emit a ParallelBlock statement.
                    block = ir.ParallelBlock()
                    block.calls.extend(gathered.calls)
                    stmt.parallel_block.CopyFrom(block)
                else:
                    # Spread used for its side effects:
                    # await asyncio.gather(*[action(x) for x in items])
                    spread_value = ir.Expr(spread_expr=gathered, span=_make_span(node))
                    stmt.assignment.CopyFrom(ir.Assignment(targets=[], value=spread_value))
                return stmt

    action_call = self._extract_action_call(inner)
    if action_call:
        stmt.action_call.CopyFrom(action_call)
        return stmt

    # name.append(x) is rewritten as name = name + [x] so the mutation is an
    # explicit assignment that data-flow tracking can see.
    is_simple_append = (
        isinstance(inner, ast.Call)
        and isinstance(inner.func, ast.Attribute)
        and inner.func.attr == "append"
        and isinstance(inner.func.value, ast.Name)
        and len(inner.args) == 1
    )
    if is_simple_append:
        receiver = inner.func.value.id
        receiver_expr = ir.Expr(variable=ir.Variable(name=receiver), span=_make_span(node))
        item_expr = _expr_to_ir(inner.args[0])
        if item_expr:
            singleton = ir.Expr(list=ir.ListExpr(elements=[item_expr]), span=_make_span(node))
            concatenation = ir.Expr(
                binary_op=ir.BinaryOp(
                    op=ir.BinaryOperator.BINARY_OP_ADD, left=receiver_expr, right=singleton
                ),
                span=_make_span(node),
            )
            stmt.assignment.CopyFrom(ir.Assignment(targets=[receiver], value=concatenation))
            return stmt

    # Fallback: keep the expression as a plain expression statement.
    plain = _expr_to_ir(inner)
    if plain:
        stmt.expr_stmt.CopyFrom(ir.ExprStmt(expr=plain))
        return stmt

    return None
1231
+
1232
def _visit_for(self, node: ast.For) -> List[ir.Statement]:
    """Convert for loop to IR.

    The loop body is emitted as a full block so it can contain multiple
    statements/calls and early `return`.

    Loop variables come from a plain name target, or from the name elements
    of a tuple/list unpacking target (``for a, b in ...`` and
    ``for [a, b] in ...`` are equivalent in Python). Non-name elements of an
    unpacking target are skipped.

    Returns:
        A single-element list holding the ForLoop statement, or an empty
        list when the iterable cannot be converted to IR.
    """
    # Get loop variables
    loop_vars: List[str] = []
    if isinstance(node.target, ast.Name):
        loop_vars.append(node.target.id)
    elif isinstance(node.target, (ast.Tuple, ast.List)):
        loop_vars.extend(elt.id for elt in node.target.elts if isinstance(elt, ast.Name))

    # Get iterable
    iterable = _expr_to_ir(node.iter)
    if not iterable:
        return []

    # Build body statements (recursively transforms nested structures)
    body_stmts: List[ir.Statement] = []
    for body_node in node.body:
        body_stmts.extend(self._visit_statement(body_node))

    stmt = ir.Statement(span=_make_span(node))
    for_loop = ir.ForLoop(
        loop_vars=loop_vars,
        iterable=iterable,
        block_body=ir.Block(statements=body_stmts, span=_make_span(node)),
    )
    stmt.for_loop.CopyFrom(for_loop)
    return [stmt]
1266
+
1267
+ def _detect_accumulator_targets(
1268
+ self, stmts: List[ir.Statement], in_scope_vars: set
1269
+ ) -> List[str]:
1270
+ """Detect out-of-scope variable modifications in for loop body.
1271
+
1272
+ Scans statements for patterns that modify variables defined outside the loop.
1273
+ Returns a list of accumulator variable names that should be set as targets.
1274
+
1275
+ Supported patterns:
1276
+ 1. List append: results.append(value) -> "results"
1277
+ 2. Dict subscript: result[key] = value -> "result"
1278
+ 3. List/set update methods: results.extend(...), results.add(...) -> "results"
1279
+
1280
+ Note: Patterns like `results = results + [x]` and `count = count + 1` create
1281
+ new assignments which are tracked via in_scope_vars and don't need special
1282
+ detection here - they're handled by the regular assignment target logic.
1283
+ """
1284
+ accumulators: List[str] = []
1285
+ seen: set = set()
1286
+
1287
+ for stmt in stmts:
1288
+ var_name = self._extract_accumulator_from_stmt(stmt, in_scope_vars)
1289
+ if var_name and var_name not in seen:
1290
+ accumulators.append(var_name)
1291
+ seen.add(var_name)
1292
+
1293
+ # Check conditionals for accumulator targets in branch bodies
1294
+ if stmt.HasField("conditional"):
1295
+ cond = stmt.conditional
1296
+ branch_blocks: list[ir.Block] = []
1297
+ if cond.HasField("if_branch") and cond.if_branch.HasField("block_body"):
1298
+ branch_blocks.append(cond.if_branch.block_body)
1299
+ for branch in cond.elif_branches:
1300
+ if branch.HasField("block_body"):
1301
+ branch_blocks.append(branch.block_body)
1302
+ if cond.HasField("else_branch") and cond.else_branch.HasField("block_body"):
1303
+ branch_blocks.append(cond.else_branch.block_body)
1304
+
1305
+ for block in branch_blocks:
1306
+ for var in self._detect_accumulator_targets(
1307
+ list(block.statements), in_scope_vars
1308
+ ):
1309
+ if var not in seen:
1310
+ accumulators.append(var)
1311
+ seen.add(var)
1312
+
1313
+ return accumulators
1314
+
1315
+ def _extract_accumulator_from_stmt(
1316
+ self, stmt: ir.Statement, in_scope_vars: set
1317
+ ) -> Optional[str]:
1318
+ """Extract accumulator variable name from a single statement.
1319
+
1320
+ Returns the variable name if this statement modifies an out-of-scope variable,
1321
+ None otherwise.
1322
+ """
1323
+ # Pattern 1: Method calls like list.append(), dict.update(), set.add()
1324
+ if stmt.HasField("expr_stmt"):
1325
+ expr = stmt.expr_stmt.expr
1326
+ if expr.HasField("function_call"):
1327
+ fn_name = expr.function_call.name
1328
+ # Check for mutating method calls: x.append, x.extend, x.add, x.update, etc.
1329
+ mutating_methods = {
1330
+ ".append",
1331
+ ".extend",
1332
+ ".add",
1333
+ ".update",
1334
+ ".insert",
1335
+ ".pop",
1336
+ ".remove",
1337
+ ".clear",
1338
+ }
1339
+ for method in mutating_methods:
1340
+ if fn_name.endswith(method):
1341
+ var_name = fn_name[: len(fn_name) - len(method)]
1342
+ # Only return if it's an out-of-scope variable
1343
+ if var_name and var_name not in in_scope_vars:
1344
+ return var_name
1345
+
1346
+ # Pattern 2: Subscript assignment like dict[key] = value
1347
+ if stmt.HasField("assignment"):
1348
+ for target in stmt.assignment.targets:
1349
+ # Check if target is a subscript pattern (contains '[')
1350
+ if "[" in target:
1351
+ # Extract base variable name (before '[')
1352
+ var_name = target.split("[")[0]
1353
+ if var_name and var_name not in in_scope_vars:
1354
+ return var_name
1355
+
1356
+ # Pattern 3: Self-referential assignment like x = x + [y]
1357
+ # The target variable is used on the RHS, so it must come from outside.
1358
+ # Note: We don't check in_scope_vars here because the assignment itself
1359
+ # would have added the target to in_scope_vars, but it still needs its
1360
+ # previous value from outside the loop body.
1361
+ if stmt.HasField("assignment"):
1362
+ assign = stmt.assignment
1363
+ rhs_vars = self._collect_variables_from_expr(assign.value)
1364
+ for target in assign.targets:
1365
+ if target in rhs_vars:
1366
+ return target
1367
+
1368
+ return None
1369
+
1370
+ def _collect_variables_from_expr(self, expr: ir.Expr) -> set:
1371
+ """Recursively collect all variable names used in an expression."""
1372
+ vars_found: set = set()
1373
+
1374
+ if expr.HasField("variable"):
1375
+ vars_found.add(expr.variable.name)
1376
+ elif expr.HasField("binary_op"):
1377
+ vars_found.update(self._collect_variables_from_expr(expr.binary_op.left))
1378
+ vars_found.update(self._collect_variables_from_expr(expr.binary_op.right))
1379
+ elif expr.HasField("unary_op"):
1380
+ vars_found.update(self._collect_variables_from_expr(expr.unary_op.operand))
1381
+ elif expr.HasField("list"):
1382
+ for elem in expr.list.elements:
1383
+ vars_found.update(self._collect_variables_from_expr(elem))
1384
+ elif expr.HasField("dict"):
1385
+ for entry in expr.dict.entries:
1386
+ vars_found.update(self._collect_variables_from_expr(entry.key))
1387
+ vars_found.update(self._collect_variables_from_expr(entry.value))
1388
+ elif expr.HasField("index"):
1389
+ vars_found.update(self._collect_variables_from_expr(expr.index.value))
1390
+ vars_found.update(self._collect_variables_from_expr(expr.index.index))
1391
+ elif expr.HasField("dot"):
1392
+ vars_found.update(self._collect_variables_from_expr(expr.dot.object))
1393
+ elif expr.HasField("function_call"):
1394
+ for kwarg in expr.function_call.kwargs:
1395
+ vars_found.update(self._collect_variables_from_expr(kwarg.value))
1396
+ elif expr.HasField("action_call"):
1397
+ for kwarg in expr.action_call.kwargs:
1398
+ vars_found.update(self._collect_variables_from_expr(kwarg.value))
1399
+
1400
+ return vars_found
1401
+
1402
def _visit_if(self, node: ast.If) -> List[ir.Statement]:
    """Convert if statement to IR.

    Normalizes patterns like:
        if await some_action(...):
            ...
    into:
        __if_cond_n__ = await some_action(...)
        if __if_cond_n__:
            ...

    An if/elif/else chain is emitted as one flat Conditional when only the
    first condition (or none) needs such a prefix assignment. When an elif
    condition needs one, the chain is emitted as nested conditionals so the
    prefix assignment sits inside the preceding else branch.

    Returns:
        The prefix statements (if any) followed by the conditional statement,
        or an empty list when the first condition cannot be converted in the
        flat case.

    Raises:
        UnsupportedPatternError: when a condition contains an action call
            that is not a direct ``await`` expression.
    """

    def normalize_condition(test: ast.expr) -> tuple[List[ir.Statement], Optional[ir.Expr]]:
        # Returns (prefix statements, condition expression). The prefix is
        # empty unless the test is an awaited action call.
        action_call = self._extract_action_call(test)
        if action_call is None:
            return ([], _expr_to_ir(test))

        if not isinstance(test, ast.Await):
            line = getattr(test, "lineno", None)
            col = getattr(test, "col_offset", None)
            raise UnsupportedPatternError(
                "Action calls inside boolean expressions are not supported in if conditions",
                "Assign the awaited action result to a variable, then use the variable in the if condition.",
                line=line,
                col=col,
            )

        # Bind the action result to an implicit variable and test that instead.
        cond_var = self._ctx.next_implicit_fn_name(prefix="if_cond")
        assign_stmt = ir.Statement(span=_make_span(test))
        assign_stmt.assignment.CopyFrom(
            ir.Assignment(
                targets=[cond_var],
                value=ir.Expr(action_call=action_call, span=_make_span(test)),
            )
        )
        cond_expr = ir.Expr(variable=ir.Variable(name=cond_var), span=_make_span(test))
        return ([assign_stmt], cond_expr)

    def visit_body(nodes: list[ast.stmt]) -> List[ir.Statement]:
        # Convert a branch body; each AST statement may yield several IR statements.
        stmts: List[ir.Statement] = []
        for body_node in nodes:
            stmts.extend(self._visit_statement(body_node))
        return stmts

    # Collect if/elif branches as (test_expr, body_nodes, span_node).
    # In the Python AST an `elif` is an `orelse` holding exactly one ast.If.
    branches: list[tuple[ast.expr, list[ast.stmt], ast.AST]] = [(node.test, node.body, node)]
    current = node
    while current.orelse and len(current.orelse) == 1 and isinstance(current.orelse[0], ast.If):
        elif_node = current.orelse[0]
        branches.append((elif_node.test, elif_node.body, elif_node))
        current = elif_node

    # Whatever remains after the last elif is the final else body (may be empty).
    else_nodes = current.orelse

    # Each entry: (prefix statements, condition, converted body, span node).
    normalized: list[
        tuple[List[ir.Statement], Optional[ir.Expr], List[ir.Statement], ast.AST]
    ] = []
    for test_expr, body_nodes, span_node in branches:
        prefix, cond = normalize_condition(test_expr)
        normalized.append((prefix, cond, visit_body(body_nodes), span_node))

    else_body = visit_body(else_nodes) if else_nodes else []

    # If any non-first branch needs normalization, preserve Python semantics by nesting.
    requires_nested = any(prefix for prefix, _, _, _ in normalized[1:])

    def build_conditional_stmt(
        condition: ir.Expr,
        then_body: List[ir.Statement],
        else_body_statements: List[ir.Statement],
        span_node: ast.AST,
    ) -> ir.Statement:
        # Build a two-way conditional (if / optional else) with no elif branches.
        conditional_stmt = ir.Statement(span=_make_span(span_node))
        if_branch = ir.IfBranch(
            condition=condition,
            block_body=ir.Block(statements=then_body, span=_make_span(span_node)),
            span=_make_span(span_node),
        )
        conditional = ir.Conditional(if_branch=if_branch)
        if else_body_statements:
            else_branch = ir.ElseBranch(
                block_body=ir.Block(
                    statements=else_body_statements,
                    span=_make_span(span_node),
                ),
                span=_make_span(span_node),
            )
            conditional.else_branch.CopyFrom(else_branch)
        conditional_stmt.conditional.CopyFrom(conditional)
        return conditional_stmt

    if requires_nested:
        # Build from the innermost branch outwards: each branch's prefix plus
        # its conditional becomes the else body of the branch before it.
        nested_else: List[ir.Statement] = else_body
        for prefix, cond, then_body, span_node in reversed(normalized):
            if cond is None:
                # Condition could not be converted; this branch is dropped.
                continue
            nested_if_stmt = build_conditional_stmt(
                condition=cond,
                then_body=then_body,
                else_body_statements=nested_else,
                span_node=span_node,
            )
            nested_else = [*prefix, nested_if_stmt]
        return nested_else

    # Flat conditional with elif/else (original behavior), plus optional prefix for the if guard.
    if_prefix, if_condition, if_body, if_span_node = normalized[0]
    if if_condition is None:
        return []

    conditional_stmt = ir.Statement(span=_make_span(if_span_node))
    if_branch = ir.IfBranch(
        condition=if_condition,
        block_body=ir.Block(statements=if_body, span=_make_span(if_span_node)),
        span=_make_span(if_span_node),
    )
    conditional = ir.Conditional(if_branch=if_branch)

    for _, elif_condition, elif_body, elif_span_node in normalized[1:]:
        if elif_condition is None:
            # Unconvertible elif condition: skip just that branch.
            continue
        elif_branch = ir.ElifBranch(
            condition=elif_condition,
            block_body=ir.Block(statements=elif_body, span=_make_span(elif_span_node)),
            span=_make_span(elif_span_node),
        )
        conditional.elif_branches.append(elif_branch)

    if else_body:
        # The else branch reuses the span of the leading `if` node.
        else_branch = ir.ElseBranch(
            block_body=ir.Block(statements=else_body, span=_make_span(if_span_node)),
            span=_make_span(if_span_node),
        )
        conditional.else_branch.CopyFrom(else_branch)

    conditional_stmt.conditional.CopyFrom(conditional)
    return [*if_prefix, conditional_stmt]
1539
+
1540
+ def _collect_assigned_vars(self, stmts: List[ir.Statement]) -> set:
1541
+ """Collect all variable names assigned in a list of statements."""
1542
+ assigned = set()
1543
+ for stmt in stmts:
1544
+ if stmt.HasField("assignment"):
1545
+ assigned.update(stmt.assignment.targets)
1546
+ return assigned
1547
+
1548
+ def _collect_assigned_vars_in_order(self, stmts: List[ir.Statement]) -> list[str]:
1549
+ """Collect assigned variable names in statement order (deduplicated)."""
1550
+ assigned: list[str] = []
1551
+ seen: set[str] = set()
1552
+
1553
+ for stmt in stmts:
1554
+ if stmt.HasField("assignment"):
1555
+ for target in stmt.assignment.targets:
1556
+ if target not in seen:
1557
+ seen.add(target)
1558
+ assigned.append(target)
1559
+
1560
+ if stmt.HasField("conditional"):
1561
+ cond = stmt.conditional
1562
+ if cond.HasField("if_branch") and cond.if_branch.HasField("block_body"):
1563
+ for target in self._collect_assigned_vars_in_order(
1564
+ list(cond.if_branch.block_body.statements)
1565
+ ):
1566
+ if target not in seen:
1567
+ seen.add(target)
1568
+ assigned.append(target)
1569
+ for elif_branch in cond.elif_branches:
1570
+ if elif_branch.HasField("block_body"):
1571
+ for target in self._collect_assigned_vars_in_order(
1572
+ list(elif_branch.block_body.statements)
1573
+ ):
1574
+ if target not in seen:
1575
+ seen.add(target)
1576
+ assigned.append(target)
1577
+ if cond.HasField("else_branch") and cond.else_branch.HasField("block_body"):
1578
+ for target in self._collect_assigned_vars_in_order(
1579
+ list(cond.else_branch.block_body.statements)
1580
+ ):
1581
+ if target not in seen:
1582
+ seen.add(target)
1583
+ assigned.append(target)
1584
+
1585
+ if stmt.HasField("for_loop") and stmt.for_loop.HasField("block_body"):
1586
+ for target in self._collect_assigned_vars_in_order(
1587
+ list(stmt.for_loop.block_body.statements)
1588
+ ):
1589
+ if target not in seen:
1590
+ seen.add(target)
1591
+ assigned.append(target)
1592
+
1593
+ if stmt.HasField("try_except"):
1594
+ try_block = stmt.try_except.try_block
1595
+ if try_block.HasField("span"):
1596
+ for target in self._collect_assigned_vars_in_order(list(try_block.statements)):
1597
+ if target not in seen:
1598
+ seen.add(target)
1599
+ assigned.append(target)
1600
+ for handler in stmt.try_except.handlers:
1601
+ if handler.HasField("block_body"):
1602
+ for target in self._collect_assigned_vars_in_order(
1603
+ list(handler.block_body.statements)
1604
+ ):
1605
+ if target not in seen:
1606
+ seen.add(target)
1607
+ assigned.append(target)
1608
+
1609
+ return assigned
1610
+
1611
def _collect_variables_from_block(self, block: ir.Block) -> list[str]:
    """Collect variable references from the statements of ``block``.

    Thin wrapper that hands the block's statements to
    ``_collect_variables_from_statements``.
    """
    block_statements = list(block.statements)
    return self._collect_variables_from_statements(block_statements)
1613
+
1614
+ def _collect_variables_from_statements(self, stmts: List[ir.Statement]) -> list[str]:
1615
+ """Collect variable references from statements in encounter order."""
1616
+ vars_found: list[str] = []
1617
+ seen: set[str] = set()
1618
+
1619
+ for stmt in stmts:
1620
+ if stmt.HasField("assignment") and stmt.assignment.HasField("value"):
1621
+ for var in self._collect_variables_from_expr(stmt.assignment.value):
1622
+ if var not in seen:
1623
+ seen.add(var)
1624
+ vars_found.append(var)
1625
+
1626
+ if stmt.HasField("return_stmt") and stmt.return_stmt.HasField("value"):
1627
+ for var in self._collect_variables_from_expr(stmt.return_stmt.value):
1628
+ if var not in seen:
1629
+ seen.add(var)
1630
+ vars_found.append(var)
1631
+
1632
+ if stmt.HasField("action_call"):
1633
+ expr = ir.Expr(action_call=stmt.action_call, span=stmt.span)
1634
+ for var in self._collect_variables_from_expr(expr):
1635
+ if var not in seen:
1636
+ seen.add(var)
1637
+ vars_found.append(var)
1638
+
1639
+ if stmt.HasField("expr_stmt"):
1640
+ for var in self._collect_variables_from_expr(stmt.expr_stmt.expr):
1641
+ if var not in seen:
1642
+ seen.add(var)
1643
+ vars_found.append(var)
1644
+
1645
+ if stmt.HasField("conditional"):
1646
+ cond = stmt.conditional
1647
+ if cond.HasField("if_branch"):
1648
+ if cond.if_branch.HasField("condition"):
1649
+ for var in self._collect_variables_from_expr(cond.if_branch.condition):
1650
+ if var not in seen:
1651
+ seen.add(var)
1652
+ vars_found.append(var)
1653
+ if cond.if_branch.HasField("block_body"):
1654
+ for var in self._collect_variables_from_block(cond.if_branch.block_body):
1655
+ if var not in seen:
1656
+ seen.add(var)
1657
+ vars_found.append(var)
1658
+ for elif_branch in cond.elif_branches:
1659
+ if elif_branch.HasField("condition"):
1660
+ for var in self._collect_variables_from_expr(elif_branch.condition):
1661
+ if var not in seen:
1662
+ seen.add(var)
1663
+ vars_found.append(var)
1664
+ if elif_branch.HasField("block_body"):
1665
+ for var in self._collect_variables_from_block(elif_branch.block_body):
1666
+ if var not in seen:
1667
+ seen.add(var)
1668
+ vars_found.append(var)
1669
+ if cond.HasField("else_branch") and cond.else_branch.HasField("block_body"):
1670
+ for var in self._collect_variables_from_block(cond.else_branch.block_body):
1671
+ if var not in seen:
1672
+ seen.add(var)
1673
+ vars_found.append(var)
1674
+
1675
+ if stmt.HasField("for_loop"):
1676
+ fl = stmt.for_loop
1677
+ if fl.HasField("iterable"):
1678
+ for var in self._collect_variables_from_expr(fl.iterable):
1679
+ if var not in seen:
1680
+ seen.add(var)
1681
+ vars_found.append(var)
1682
+ if fl.HasField("block_body"):
1683
+ for var in self._collect_variables_from_block(fl.block_body):
1684
+ if var not in seen:
1685
+ seen.add(var)
1686
+ vars_found.append(var)
1687
+
1688
+ if stmt.HasField("try_except"):
1689
+ te = stmt.try_except
1690
+ if te.HasField("try_block"):
1691
+ for var in self._collect_variables_from_block(te.try_block):
1692
+ if var not in seen:
1693
+ seen.add(var)
1694
+ vars_found.append(var)
1695
+ for handler in te.handlers:
1696
+ if handler.HasField("block_body"):
1697
+ for var in self._collect_variables_from_block(handler.block_body):
1698
+ if var not in seen:
1699
+ seen.add(var)
1700
+ vars_found.append(var)
1701
+
1702
+ if stmt.HasField("parallel_block"):
1703
+ for call in stmt.parallel_block.calls:
1704
+ if call.HasField("action"):
1705
+ for kwarg in call.action.kwargs:
1706
+ for var in self._collect_variables_from_expr(kwarg.value):
1707
+ if var not in seen:
1708
+ seen.add(var)
1709
+ vars_found.append(var)
1710
+ elif call.HasField("function"):
1711
+ for kwarg in call.function.kwargs:
1712
+ for var in self._collect_variables_from_expr(kwarg.value):
1713
+ if var not in seen:
1714
+ seen.add(var)
1715
+ vars_found.append(var)
1716
+
1717
+ if stmt.HasField("spread_action"):
1718
+ spread = stmt.spread_action
1719
+ if spread.HasField("collection"):
1720
+ for var in self._collect_variables_from_expr(spread.collection):
1721
+ if var not in seen:
1722
+ seen.add(var)
1723
+ vars_found.append(var)
1724
+ if spread.HasField("action"):
1725
+ for kwarg in spread.action.kwargs:
1726
+ for var in self._collect_variables_from_expr(kwarg.value):
1727
+ if var not in seen:
1728
+ seen.add(var)
1729
+ vars_found.append(var)
1730
+
1731
+ return vars_found
1732
+
1733
def _visit_try(self, node: ast.Try) -> List[ir.Statement]:
    """Translate a try/except statement into a single TryExcept IR statement.

    The try body and every handler body are converted statement by statement
    into full blocks. Exception types are recorded by name: a bare name gives
    one entry, a tuple gives one entry per name element, and any other form
    (including a bare ``except:``) leaves the list empty.

    Returns:
        A single-element list holding the TryExcept statement.
    """
    # Statements protected by the try.
    protected: List[ir.Statement] = []
    for child in node.body:
        protected.extend(self._visit_statement(child))

    ir_handlers: List[ir.ExceptHandler] = []
    for clause in node.handlers:
        caught = clause.type
        if isinstance(caught, ast.Name):
            type_names: List[str] = [caught.id]
        elif isinstance(caught, ast.Tuple):
            type_names = [elt.id for elt in caught.elts if isinstance(elt, ast.Name)]
        else:
            type_names = []

        # Handler body, converted recursively like the try body.
        recovery: List[ir.Statement] = []
        for child in clause.body:
            recovery.extend(self._visit_statement(child))

        ir_handlers.append(
            ir.ExceptHandler(
                exception_types=type_names,
                block_body=ir.Block(statements=recovery, span=_make_span(clause)),
                span=_make_span(clause),
            )
        )

    result = ir.Statement(span=_make_span(node))
    result.try_except.CopyFrom(
        ir.TryExcept(
            handlers=ir_handlers,
            try_block=ir.Block(statements=protected, span=_make_span(node)),
        )
    )
    return [result]
1775
+
1776
+ def _count_calls(self, stmts: List[ir.Statement]) -> int:
1777
+ """Count action calls and function calls in statements.
1778
+
1779
+ Both action calls and function calls (including synthetic functions)
1780
+ count toward the limit of one call per control flow body.
1781
+ """
1782
+ count = 0
1783
+ for stmt in stmts:
1784
+ if stmt.HasField("action_call"):
1785
+ count += 1
1786
+ elif stmt.HasField("assignment"):
1787
+ # Check if assignment value is an action call or function call
1788
+ if stmt.assignment.value.HasField("action_call"):
1789
+ count += 1
1790
+ elif stmt.assignment.value.HasField("function_call"):
1791
+ count += 1
1792
+ elif stmt.HasField("expr_stmt"):
1793
+ # Check if expression is a function call
1794
+ if stmt.expr_stmt.expr.HasField("function_call"):
1795
+ count += 1
1796
+ return count
1797
+
1798
def _wrap_body_as_function(
    self,
    body: List[ir.Statement],
    prefix: str,
    node: ast.AST,
    inputs: Optional[List[str]] = None,
    modified_vars: Optional[List[str]] = None,
) -> List[ir.Statement]:
    """Hoist ``body`` into a synthetic function and return a call to it.

    The synthetic ``ir.FunctionDef`` is appended to
    ``self._ctx.implicit_functions``; its name comes from
    ``self._ctx.next_implicit_fn_name(prefix)``.

    Args:
        body: Statements that become the synthetic function's body.
        prefix: Name prefix handed to the implicit-name generator.
        node: AST node used for span information on everything created here.
        inputs: Variable names passed into the function (e.g. loop variables).
        modified_vars: Variables the body modifies. Each one is added to the
            inputs (if not already there), returned by the function, and
            re-assigned at the call site.

    Returns:
        A one-element list: an assignment of the call result to
        ``modified_vars`` when there are any, otherwise a bare expression
        statement containing the call.
    """
    fn_name = self._ctx.next_implicit_fn_name(prefix)

    outputs = modified_vars or []
    params = list(inputs or [])
    # Modified variables must also flow in, without duplicating a parameter.
    for name in outputs:
        if name not in params:
            params.append(name)

    fn_body = list(body)
    if outputs:
        # A single output is returned directly; several are returned as a
        # list (tuples are represented as lists in IR).
        if len(outputs) == 1:
            returned = ir.Expr(
                variable=ir.Variable(name=outputs[0]),
                span=_make_span(node),
            )
        else:
            members = [ir.Expr(variable=ir.Variable(name=name)) for name in outputs]
            returned = ir.Expr(
                list=ir.ListExpr(elements=members),
                span=_make_span(node),
            )
        trailing_return = ir.Statement(span=_make_span(node))
        trailing_return.return_stmt.CopyFrom(ir.ReturnStmt(value=returned))
        fn_body.append(trailing_return)

    self._ctx.implicit_functions.append(
        ir.FunctionDef(
            name=fn_name,
            io=ir.IoDecl(inputs=params, outputs=outputs),
            body=ir.Block(statements=fn_body),
            span=_make_span(node),
        )
    )

    # Every parameter is forwarded under its own name.
    call_kwargs = [
        ir.Kwarg(name=name, value=ir.Expr(variable=ir.Variable(name=name)))
        for name in params
    ]
    call_expr = ir.Expr(
        function_call=ir.FunctionCall(name=fn_name, kwargs=call_kwargs),
        span=_make_span(node),
    )

    call_stmt = ir.Statement(span=_make_span(node))
    if not outputs:
        call_stmt.expr_stmt.CopyFrom(ir.ExprStmt(expr=call_expr))
        return [call_stmt]

    # var1, var2 = fn(...)  or  var1 = fn(...)
    write_back = ir.Assignment(value=call_expr)
    write_back.targets.extend(outputs)
    call_stmt.assignment.CopyFrom(write_back)
    return [call_stmt]
1880
+
1881
def _visit_return(self, node: ast.Return) -> List[ir.Statement]:
    """Convert a ``return`` statement to IR.

    A returned value is first checked with
    ``self._check_constructor_in_return``, which may raise
    ``UnsupportedPatternError``. Then:

      * ``return await action()`` is normalized to
        ``_return_tmp = await action(); return _return_tmp``.
      * A returned function call is normalized the same way, using a
        temporary name from
        ``self._ctx.next_implicit_fn_name(prefix="return_tmp")``.
      * Any other expression that ``_expr_to_ir`` converts is returned
        directly.
      * With no value, or when conversion yields nothing, an empty
        ``ReturnStmt`` is emitted.
    """

    def _wrap_return(payload: ir.ReturnStmt) -> ir.Statement:
        wrapper = ir.Statement(span=_make_span(node))
        wrapper.return_stmt.CopyFrom(payload)
        return wrapper

    def _return_via_temp(temp_name: str, value: ir.Expr) -> List[ir.Statement]:
        # temp_name = <value>; return temp_name
        binding = ir.Statement(span=_make_span(node))
        binding.assignment.CopyFrom(ir.Assignment(targets=[temp_name], value=value))
        temp_ref = ir.Expr(variable=ir.Variable(name=temp_name), span=_make_span(node))
        return [binding, _wrap_return(ir.ReturnStmt(value=temp_ref))]

    returned = node.value
    if returned:
        # Check for constructor calls in return (e.g., return MyModel(...))
        self._check_constructor_in_return(returned)

        awaited_action = self._extract_action_call(returned)
        if awaited_action:
            action_expr = ir.Expr(action_call=awaited_action, span=_make_span(node))
            return _return_via_temp("_return_tmp", action_expr)

        converted = _expr_to_ir(returned)
        if converted:
            if converted.HasField("function_call"):
                temp_name = self._ctx.next_implicit_fn_name(prefix="return_tmp")
                return _return_via_temp(temp_name, converted)
            # Regular return with expression (variable, literal, etc.)
            return [_wrap_return(ir.ReturnStmt(value=converted))]

    # Return with no value
    return [_wrap_return(ir.ReturnStmt())]
1942
+
1943
def _visit_aug_assign(self, node: ast.AugAssign) -> List[ir.Statement]:
    """Convert an augmented assignment (``+=``, ``-=``, ...) to IR.

    ``target op= value`` is emitted as ``target = target op value``. When the
    right-hand side converts to a function call, the call is first bound to
    a temporary (named via
    ``self._ctx.next_implicit_fn_name(prefix="aug_tmp")``) and the binary
    operation uses that temporary instead.

    The assignment's target list contains the target's name only when the
    target is a plain ``ast.Name``; otherwise it is left empty.

    Returns:
        ``[tmp_assign, assign]`` for a function-call right-hand side,
        ``[tmp_assign]`` alone if the binary assignment cannot be built in
        that case, ``[assign]`` for the ordinary case, and ``[]`` when the
        operands or operator cannot be converted.
    """
    target_names: List[str] = [node.target.id] if isinstance(node.target, ast.Name) else []

    lhs = _expr_to_ir(node.target)
    rhs = _expr_to_ir(node.value)

    def _binary_assignment(right_operand: ir.Expr) -> Optional[ir.Statement]:
        # Build ``target = target <op> right_operand``; None when the target
        # or the operator has no IR form.
        if not lhs:
            return None
        operator = _bin_op_to_ir(node.op)
        if not operator:
            return None
        combined = ir.Expr(binary_op=ir.BinaryOp(left=lhs, op=operator, right=right_operand))
        result = ir.Statement(span=_make_span(node))
        result.assignment.CopyFrom(ir.Assignment(targets=target_names, value=combined))
        return result

    if rhs and rhs.HasField("function_call"):
        temp_name = self._ctx.next_implicit_fn_name(prefix="aug_tmp")
        temp_binding = ir.Statement(span=_make_span(node))
        temp_binding.assignment.CopyFrom(
            ir.Assignment(
                targets=[temp_name],
                value=ir.Expr(function_call=rhs.function_call, span=_make_span(node)),
            )
        )
        update = _binary_assignment(ir.Expr(variable=ir.Variable(name=temp_name)))
        if update is None:
            return [temp_binding]
        return [temp_binding, update]

    if not rhs:
        return []

    update = _binary_assignment(rhs)
    return [] if update is None else [update]
1990
+
1991
+ def _collect_function_inputs(self, node: ast.FunctionDef | ast.AsyncFunctionDef) -> List[str]:
1992
+ """Collect workflow inputs from function parameters, including kw-only args."""
1993
+ args: List[str] = []
1994
+ seen: set[str] = set()
1995
+
1996
+ ordered_args = list(node.args.posonlyargs) + list(node.args.args)
1997
+ if ordered_args and ordered_args[0].arg == "self":
1998
+ ordered_args = ordered_args[1:]
1999
+
2000
+ for arg in ordered_args:
2001
+ if arg.arg not in seen:
2002
+ args.append(arg.arg)
2003
+ seen.add(arg.arg)
2004
+
2005
+ if node.args.vararg and node.args.vararg.arg not in seen:
2006
+ args.append(node.args.vararg.arg)
2007
+ seen.add(node.args.vararg.arg)
2008
+
2009
+ for arg in node.args.kwonlyargs:
2010
+ if arg.arg not in seen:
2011
+ args.append(arg.arg)
2012
+ seen.add(arg.arg)
2013
+
2014
+ if node.args.kwarg and node.args.kwarg.arg not in seen:
2015
+ args.append(node.args.kwarg.arg)
2016
+ seen.add(node.args.kwarg.arg)
2017
+
2018
+ return args
2019
+
2020
+ def _check_constructor_in_return(self, node: ast.expr) -> None:
2021
+ """Check for constructor calls in return statements.
2022
+
2023
+ Raises UnsupportedPatternError if the return value is a class instantiation
2024
+ like: return MyModel(field=value)
2025
+
2026
+ This is not supported because the workflow IR cannot serialize arbitrary
2027
+ object instantiation. Users should use an @action to create objects.
2028
+ """
2029
+ # Skip if it's an await (action call) - those are fine
2030
+ if isinstance(node, ast.Await):
2031
+ return
2032
+
2033
+ # Check for direct Call that looks like a constructor
2034
+ if isinstance(node, ast.Call):
2035
+ func_name = self._get_constructor_name(node.func)
2036
+ if func_name and self._looks_like_constructor(func_name, node):
2037
+ line = getattr(node, "lineno", None)
2038
+ col = getattr(node, "col_offset", None)
2039
+ raise UnsupportedPatternError(
2040
+ f"Returning constructor call '{func_name}(...)' is not supported",
2041
+ RECOMMENDATIONS["constructor_return"],
2042
+ line=line,
2043
+ col=col,
2044
+ )
2045
+
2046
+ def _check_constructor_in_assignment(self, node: ast.expr) -> None:
2047
+ """Check for constructor calls in assignments.
2048
+
2049
+ Raises UnsupportedPatternError if the assignment value is a class instantiation
2050
+ like: result = MyModel(field=value)
2051
+
2052
+ This is not supported because the workflow IR cannot serialize arbitrary
2053
+ object instantiation. Users should use an @action to create objects.
2054
+ """
2055
+ # Skip if it's an await (action call) - those are fine
2056
+ if isinstance(node, ast.Await):
2057
+ return
2058
+
2059
+ # Check for direct Call that looks like a constructor
2060
+ if isinstance(node, ast.Call):
2061
+ func_name = self._get_constructor_name(node.func)
2062
+ if func_name and self._looks_like_constructor(func_name, node):
2063
+ line = getattr(node, "lineno", None)
2064
+ col = getattr(node, "col_offset", None)
2065
+ raise UnsupportedPatternError(
2066
+ f"Assigning constructor call '{func_name}(...)' is not supported",
2067
+ RECOMMENDATIONS["constructor_assignment"],
2068
+ line=line,
2069
+ col=col,
2070
+ )
2071
+
2072
+ def _get_constructor_name(self, func: ast.expr) -> Optional[str]:
2073
+ """Get the name from a function expression if it looks like a constructor."""
2074
+ if isinstance(func, ast.Name):
2075
+ return func.id
2076
+ elif isinstance(func, ast.Attribute):
2077
+ return func.attr
2078
+ return None
2079
+
2080
+ def _looks_like_constructor(self, func_name: str, call: ast.Call) -> bool:
2081
+ """Check if a function call looks like a class constructor.
2082
+
2083
+ A constructor is identified by:
2084
+ 1. Name starts with uppercase (PEP8 convention for classes)
2085
+ 2. It's not a known action
2086
+ 3. It's not a known builtin like String operations
2087
+ 4. It's not a known Pydantic model or dataclass (those are allowed)
2088
+
2089
+ This is a heuristic - we can't perfectly distinguish constructors
2090
+ from functions without full type information.
2091
+ """
2092
+ # Check if first letter is uppercase (class naming convention)
2093
+ if not func_name or not func_name[0].isupper():
2094
+ return False
2095
+
2096
+ # If it's a known action, it's not a constructor
2097
+ if func_name in self._action_defs:
2098
+ return False
2099
+
2100
+ # If it's a known Pydantic model or dataclass, allow it
2101
+ # (it will be converted to a dict expression)
2102
+ if func_name in self._model_defs:
2103
+ return False
2104
+
2105
+ # Common builtins that start with uppercase but aren't constructors
2106
+ # (these are rarely used in workflow code but let's be safe)
2107
+ builtin_exceptions = {"True", "False", "None", "Ellipsis"}
2108
+ if func_name in builtin_exceptions:
2109
+ return False
2110
+
2111
+ return True
2112
+
2113
+ def _is_model_constructor(self, node: ast.expr) -> Optional[str]:
2114
+ """Check if an expression is a Pydantic model or dataclass constructor call.
2115
+
2116
+ Returns the model name if it is, None otherwise.
2117
+ """
2118
+ if not isinstance(node, ast.Call):
2119
+ return None
2120
+
2121
+ func_name = self._get_constructor_name(node.func)
2122
+ if func_name and func_name in self._model_defs:
2123
+ return func_name
2124
+
2125
+ return None
2126
+
2127
def _convert_model_constructor_to_dict(self, node: ast.Call, model_name: str) -> ir.Expr:
    """Rewrite a model constructor call as an IR dict expression.

    For example ``MyModel(field1=value1, field2=value2)`` becomes
    ``{"field1": value1, "field2": value2}``.

    Entries are emitted in this order: keyword arguments as written,
    positional arguments (mapped onto the model's fields in declaration
    order), then defaults for fields that were not supplied. A default is
    skipped when ``_constant_to_literal`` cannot turn it into a literal.

    Args:
        node: The constructor call.
        model_name: Key into ``self._model_defs`` for the model being built.

    Raises:
        UnsupportedPatternError: For ``**kwargs`` expansion, for more
            positional arguments than the model has fields, or for an
            argument value that ``_expr_to_ir`` cannot convert.
    """
    model_def = self._model_defs[model_name]

    def _reject(message: str, recommendation: str) -> None:
        # All errors raised here point at the constructor call itself.
        raise UnsupportedPatternError(
            message,
            recommendation,
            line=getattr(node, "lineno", None),
            col=getattr(node, "col_offset", None),
        )

    def _string_key(name: str) -> ir.Expr:
        literal = ir.Literal()
        literal.string_value = name
        key = ir.Expr()
        key.literal.CopyFrom(literal)
        return key

    entries: List[ir.DictEntry] = []
    supplied: Set[str] = set()

    # Explicit keyword arguments.
    for keyword in node.keywords:
        if keyword.arg is None:
            _reject(
                f"Model constructor '{model_name}' with **kwargs is not supported",
                "Use explicit keyword arguments instead of **kwargs.",
            )
        supplied.add(keyword.arg)

        converted = _expr_to_ir(keyword.value)
        if converted is None:
            _reject(
                f"Cannot convert value for field '{keyword.arg}' in '{model_name}'",
                "Use simpler expressions (literals, variables, dicts, lists).",
            )
        entries.append(ir.DictEntry(key=_string_key(keyword.arg), value=converted))

    # Positional arguments map to fields in declaration order.
    if node.args:
        declared = list(model_def.fields.keys())
        for position, positional in enumerate(node.args):
            if position >= len(declared):
                _reject(
                    f"Too many positional arguments for '{model_name}'",
                    "Use keyword arguments for clarity.",
                )
            field_name = declared[position]
            supplied.add(field_name)

            converted = _expr_to_ir(positional)
            if converted is None:
                _reject(
                    f"Cannot convert positional argument for field '{field_name}' in '{model_name}'",
                    "Use simpler expressions (literals, variables, dicts, lists).",
                )
            entries.append(ir.DictEntry(key=_string_key(field_name), value=converted))

    # Defaults for everything the call did not supply.
    for field_name, field_def in model_def.fields.items():
        if field_name in supplied or not field_def.has_default:
            continue

        default_literal = _constant_to_literal(field_def.default_value)
        if default_literal is None:
            # Not representable as a literal - leave the field out.
            continue

        default_expr = ir.Expr()
        default_expr.literal.CopyFrom(default_literal)
        entries.append(ir.DictEntry(key=_string_key(field_name), value=default_expr))

    dict_expr = ir.Expr(span=_make_span(node))
    dict_expr.dict.CopyFrom(ir.DictExpr(entries=entries))
    return dict_expr
2239
+
2240
+ def _check_non_action_await(self, node: ast.Await) -> None:
2241
+ """Check if an await is for a non-action function.
2242
+
2243
+ Note: We can only reliably detect non-action awaits for functions defined
2244
+ in the same module. Actions imported from other modules will pass through
2245
+ and may fail at runtime if they're not actually actions.
2246
+
2247
+ For now, we only check against common builtins and known non-action patterns.
2248
+ A runtime check will catch functions that aren't registered actions.
2249
+ """
2250
+ awaited = node.value
2251
+ if not isinstance(awaited, ast.Call):
2252
+ return
2253
+
2254
+ # Skip special cases that are handled elsewhere
2255
+ if self._is_run_action_call(awaited):
2256
+ return
2257
+ if self._is_asyncio_sleep_call(awaited):
2258
+ return
2259
+ if self._is_asyncio_gather_call(awaited):
2260
+ return
2261
+
2262
+ # Get the function name
2263
+ func_name = None
2264
+ if isinstance(awaited.func, ast.Name):
2265
+ func_name = awaited.func.id
2266
+ elif isinstance(awaited.func, ast.Attribute):
2267
+ func_name = awaited.func.attr
2268
+
2269
+ if not func_name:
2270
+ return
2271
+
2272
+ # Only raise error for functions defined in THIS module that we know
2273
+ # are NOT actions (i.e., async functions without @action decorator)
2274
+ # We can't reliably detect imported non-actions without full type info.
2275
+ #
2276
+ # The check works by looking at _module_functions which contains
2277
+ # functions defined in the same module as the workflow.
2278
+ if func_name in getattr(self, "_module_functions", set()):
2279
+ if func_name not in self._action_defs:
2280
+ line = getattr(node, "lineno", None)
2281
+ col = getattr(node, "col_offset", None)
2282
+ raise UnsupportedPatternError(
2283
+ f"Awaiting non-action function '{func_name}()' is not supported",
2284
+ RECOMMENDATIONS["non_action_call"],
2285
+ line=line,
2286
+ col=col,
2287
+ )
2288
+
2289
+ def _check_sync_function_call(self, node: ast.Call) -> None:
2290
+ """Check for synchronous function calls that should be in actions.
2291
+
2292
+ Common patterns like len(), str(), etc. are not supported in workflow code.
2293
+ """
2294
+ func_name = None
2295
+ if isinstance(node.func, ast.Name):
2296
+ func_name = node.func.id
2297
+ elif isinstance(node.func, ast.Attribute):
2298
+ # Method calls on objects - check the method name
2299
+ func_name = node.func.attr
2300
+
2301
+ if not func_name:
2302
+ return
2303
+
2304
+ # Builtins that users commonly try to use
2305
+ common_builtins = {
2306
+ "len",
2307
+ "str",
2308
+ "int",
2309
+ "float",
2310
+ "bool",
2311
+ "list",
2312
+ "dict",
2313
+ "set",
2314
+ "tuple",
2315
+ "sum",
2316
+ "min",
2317
+ "max",
2318
+ "sorted",
2319
+ "reversed",
2320
+ "enumerate",
2321
+ "zip",
2322
+ "map",
2323
+ "filter",
2324
+ "range",
2325
+ "abs",
2326
+ "round",
2327
+ "print",
2328
+ "type",
2329
+ "isinstance",
2330
+ "hasattr",
2331
+ "getattr",
2332
+ "setattr",
2333
+ "open",
2334
+ "format",
2335
+ }
2336
+
2337
+ if func_name in common_builtins:
2338
+ line = getattr(node, "lineno", None)
2339
+ col = getattr(node, "col_offset", None)
2340
+ raise UnsupportedPatternError(
2341
+ f"Calling built-in function '{func_name}()' directly is not supported",
2342
+ RECOMMENDATIONS["builtin_call"],
2343
+ line=line,
2344
+ col=col,
2345
+ )
2346
+
2347
+ def _extract_action_call(self, node: ast.expr) -> Optional[ir.ActionCall]:
2348
+ """Extract an action call from an expression if present.
2349
+
2350
+ Also validates that awaited calls are actually @action decorated functions.
2351
+ Raises UnsupportedPatternError if awaiting a non-action function.
2352
+ """
2353
+ if not isinstance(node, ast.Await):
2354
+ return None
2355
+
2356
+ awaited = node.value
2357
+ # Handle self.run_action(...) wrapper
2358
+ if isinstance(awaited, ast.Call):
2359
+ if self._is_run_action_call(awaited):
2360
+ # Extract the actual action call from run_action
2361
+ if awaited.args:
2362
+ action_call = self._extract_action_call_from_awaitable(awaited.args[0])
2363
+ if action_call:
2364
+ # Extract policies from run_action kwargs (retry, timeout)
2365
+ self._extract_policies_from_run_action(awaited, action_call)
2366
+ return action_call
2367
+ # Check for asyncio.sleep() - convert to @sleep action
2368
+ if self._is_asyncio_sleep_call(awaited):
2369
+ return self._convert_asyncio_sleep_to_action(awaited)
2370
+ # Try to extract as action call
2371
+ action_call = self._extract_action_call_from_call(awaited)
2372
+ if action_call:
2373
+ return action_call
2374
+
2375
+ # If we get here, it's an await of a non-action function
2376
+ self._check_non_action_await(node)
2377
+ return None
2378
+
2379
+ return None
2380
+
2381
+ def _is_run_action_call(self, node: ast.Call) -> bool:
2382
+ """Check if this is a self.run_action(...) call."""
2383
+ if isinstance(node.func, ast.Attribute):
2384
+ return node.func.attr == "run_action"
2385
+ return False
2386
+
2387
+ def _extract_policies_from_run_action(
2388
+ self, run_action_call: ast.Call, action_call: ir.ActionCall
2389
+ ) -> None:
2390
+ """Extract retry and timeout policies from run_action kwargs.
2391
+
2392
+ Parses patterns like:
2393
+ - self.run_action(action(), retry=RetryPolicy(attempts=3))
2394
+ - self.run_action(action(), timeout=timedelta(seconds=30))
2395
+ - self.run_action(action(), timeout=60)
2396
+ """
2397
+ for kw in run_action_call.keywords:
2398
+ if kw.arg == "retry":
2399
+ retry_policy = self._parse_retry_policy(kw.value)
2400
+ if retry_policy:
2401
+ policy_bracket = ir.PolicyBracket()
2402
+ policy_bracket.retry.CopyFrom(retry_policy)
2403
+ action_call.policies.append(policy_bracket)
2404
+ elif kw.arg == "timeout":
2405
+ timeout_policy = self._parse_timeout_policy(kw.value)
2406
+ if timeout_policy:
2407
+ policy_bracket = ir.PolicyBracket()
2408
+ policy_bracket.timeout.CopyFrom(timeout_policy)
2409
+ action_call.policies.append(policy_bracket)
2410
+
2411
+ def _parse_retry_policy(self, node: ast.expr) -> Optional[ir.RetryPolicy]:
2412
+ """Parse a RetryPolicy(...) call into IR.
2413
+
2414
+ Supports:
2415
+ - RetryPolicy(attempts=3)
2416
+ - RetryPolicy(attempts=3, exception_types=["ValueError"])
2417
+ - RetryPolicy(attempts=3, backoff_seconds=5)
2418
+ """
2419
+ if not isinstance(node, ast.Call):
2420
+ return None
2421
+
2422
+ # Check if it's a RetryPolicy call
2423
+ func_name = None
2424
+ if isinstance(node.func, ast.Name):
2425
+ func_name = node.func.id
2426
+ elif isinstance(node.func, ast.Attribute):
2427
+ func_name = node.func.attr
2428
+
2429
+ if func_name != "RetryPolicy":
2430
+ return None
2431
+
2432
+ policy = ir.RetryPolicy()
2433
+
2434
+ for kw in node.keywords:
2435
+ if kw.arg == "attempts" and isinstance(kw.value, ast.Constant):
2436
+ policy.max_retries = kw.value.value
2437
+ elif kw.arg == "exception_types" and isinstance(kw.value, ast.List):
2438
+ for elt in kw.value.elts:
2439
+ if isinstance(elt, ast.Constant) and isinstance(elt.value, str):
2440
+ policy.exception_types.append(elt.value)
2441
+ elif kw.arg == "backoff_seconds" and isinstance(kw.value, ast.Constant):
2442
+ policy.backoff.seconds = int(kw.value.value)
2443
+
2444
+ return policy
2445
+
2446
+ def _parse_timeout_policy(self, node: ast.expr) -> Optional[ir.TimeoutPolicy]:
2447
+ """Parse a timeout value into IR.
2448
+
2449
+ Supports:
2450
+ - timeout=60 (int seconds)
2451
+ - timeout=30.5 (float seconds)
2452
+ - timeout=timedelta(seconds=30)
2453
+ - timeout=timedelta(minutes=2)
2454
+ """
2455
+ policy = ir.TimeoutPolicy()
2456
+
2457
+ # Direct numeric value (seconds)
2458
+ if isinstance(node, ast.Constant) and isinstance(node.value, (int, float)):
2459
+ policy.timeout.seconds = int(node.value)
2460
+ return policy
2461
+
2462
+ # timedelta(...) call
2463
+ if isinstance(node, ast.Call):
2464
+ func_name = None
2465
+ if isinstance(node.func, ast.Name):
2466
+ func_name = node.func.id
2467
+ elif isinstance(node.func, ast.Attribute):
2468
+ func_name = node.func.attr
2469
+
2470
+ if func_name == "timedelta":
2471
+ total_seconds = 0
2472
+ for kw in node.keywords:
2473
+ if isinstance(kw.value, ast.Constant):
2474
+ val = kw.value.value
2475
+ if kw.arg == "seconds":
2476
+ total_seconds += int(val)
2477
+ elif kw.arg == "minutes":
2478
+ total_seconds += int(val) * 60
2479
+ elif kw.arg == "hours":
2480
+ total_seconds += int(val) * 3600
2481
+ elif kw.arg == "days":
2482
+ total_seconds += int(val) * 86400
2483
+ policy.timeout.seconds = total_seconds
2484
+ return policy
2485
+
2486
+ return None
2487
+
2488
+ def _is_asyncio_sleep_call(self, node: ast.Call) -> bool:
2489
+ """Check if this is an asyncio.sleep(...) call.
2490
+
2491
+ Supports both patterns:
2492
+ - import asyncio; asyncio.sleep(1)
2493
+ - from asyncio import sleep; sleep(1)
2494
+ - from asyncio import sleep as s; s(1)
2495
+ """
2496
+ if isinstance(node.func, ast.Attribute):
2497
+ # asyncio.sleep(...) pattern
2498
+ if node.func.attr == "sleep" and isinstance(node.func.value, ast.Name):
2499
+ return node.func.value.id == "asyncio"
2500
+ elif isinstance(node.func, ast.Name):
2501
+ # sleep(...) pattern - check if it's imported from asyncio
2502
+ func_name = node.func.id
2503
+ if func_name in self._imported_names:
2504
+ imported = self._imported_names[func_name]
2505
+ return imported.module == "asyncio" and imported.original_name == "sleep"
2506
+ return False
2507
+
2508
def _convert_asyncio_sleep_to_action(self, node: ast.Call) -> ir.ActionCall:
    """Convert ``asyncio.sleep(duration)`` into an action call named ``sleep``.

    The duration is taken from the first positional argument when there is
    one, otherwise from the first keyword named ``seconds``, ``delay`` or
    ``duration``. It is attached as the ``duration`` kwarg when
    ``_expr_to_ir`` can convert it; otherwise the action call is returned
    without a ``duration`` kwarg.
    """
    sleep_action = ir.ActionCall(action_name="sleep")

    if node.args:
        # asyncio.sleep(1) - positional
        duration_node = node.args[0]
    else:
        # asyncio.sleep(seconds=1) - keyword (less common)
        duration_node = next(
            (
                keyword.value
                for keyword in node.keywords
                if keyword.arg in ("seconds", "delay", "duration")
            ),
            None,
        )

    if duration_node is not None:
        duration = _expr_to_ir(duration_node)
        if duration:
            sleep_action.kwargs.append(ir.Kwarg(name="duration", value=duration))

    return sleep_action
2532
+
2533
+ def _is_asyncio_gather_call(self, node: ast.Call) -> bool:
2534
+ """Check if this is an asyncio.gather(...) call.
2535
+
2536
+ Supports both patterns:
2537
+ - import asyncio; asyncio.gather(a(), b())
2538
+ - from asyncio import gather; gather(a(), b())
2539
+ - from asyncio import gather as g; g(a(), b())
2540
+ """
2541
+ if isinstance(node.func, ast.Attribute):
2542
+ # asyncio.gather(...) pattern
2543
+ if node.func.attr == "gather" and isinstance(node.func.value, ast.Name):
2544
+ return node.func.value.id == "asyncio"
2545
+ elif isinstance(node.func, ast.Name):
2546
+ # gather(...) pattern - check if it's imported from asyncio
2547
+ func_name = node.func.id
2548
+ if func_name in self._imported_names:
2549
+ imported = self._imported_names[func_name]
2550
+ return imported.module == "asyncio" and imported.original_name == "gather"
2551
+ return False
2552
+
2553
def _convert_asyncio_gather(
    self, node: ast.Call
) -> Optional[Union[ir.ParallelExpr, ir.SpreadExpr]]:
    """Convert ``asyncio.gather(...)`` to ``ParallelExpr`` or ``SpreadExpr`` IR.

    Two shapes are handled:
      1. Static gather, ``asyncio.gather(a(), b(), c())`` -> ``ParallelExpr``
         holding one call per argument that
         ``_convert_gather_arg_to_call`` converts.
      2. Spread gather, ``asyncio.gather(*[action(x) for x in items])`` ->
         the result of ``_convert_listcomp_to_spread_expr``.

    Args:
        node: The ``asyncio.gather()`` call node.

    Returns:
        The converted expression, or ``None`` when a static gather yields
        no calls.

    Raises:
        UnsupportedPatternError: When the sole argument is a starred
            expression that is not a list comprehension.
    """
    only_starred = len(node.args) == 1 and isinstance(node.args[0], ast.Starred)
    if only_starred:
        spread_source = node.args[0].value
        # Only list comprehensions are supported for spread
        if isinstance(spread_source, ast.ListComp):
            return self._convert_listcomp_to_spread_expr(spread_source)

        if isinstance(spread_source, ast.Name):
            message = (
                f"Spreading variable '{spread_source.id}' in asyncio.gather() is not supported"
            )
        else:
            message = (
                "Spreading non-list-comprehension expressions in asyncio.gather() is not supported"
            )
        raise UnsupportedPatternError(
            message,
            RECOMMENDATIONS["gather_variable_spread"],
            line=getattr(node, "lineno", None),
            col=getattr(node, "col_offset", None),
        )

    # Standard case: gather(a(), b(), c()) -> ParallelExpr
    parallel = ir.ParallelExpr()
    for argument in node.args:
        converted = self._convert_gather_arg_to_call(argument)
        if converted:
            parallel.calls.append(converted)

    # An empty parallel block is reported as "nothing to convert".
    return parallel if parallel.calls else None
2608
+
2609
def _convert_listcomp_to_spread_expr(self, listcomp: ast.ListComp) -> Optional[ir.SpreadExpr]:
    """Turn ``[action(x=item) for item in collection]`` into a SpreadExpr.

    The comprehension is accepted only when it has exactly one generator,
    no ``if`` clauses, a plain-name loop variable, a convertible iterable,
    and an element that is a call to a registered action.

    Args:
        listcomp: The ListComp AST node

    Returns:
        The populated SpreadExpr.

    Raises:
        UnsupportedPatternError: If any of the requirements above is not met.
    """

    def unsupported(message: str, recommendation: str) -> UnsupportedPatternError:
        # Every rejection reports the position of the comprehension itself.
        return UnsupportedPatternError(
            message,
            recommendation,
            line=getattr(listcomp, "lineno", None),
            col=getattr(listcomp, "col_offset", None),
        )

    # Exactly one ``for`` clause is allowed.
    if len(listcomp.generators) != 1:
        raise unsupported(
            "Spread pattern only supports a single loop variable",
            "Use a simple list comprehension: [action(x) for x in items]",
        )
    generator = listcomp.generators[0]

    # Filtering clauses are not supported.
    if generator.ifs:
        raise unsupported(
            "Spread pattern does not support conditions in list comprehension",
            "Remove the 'if' clause from the comprehension",
        )

    # The loop target has to be a bare name (no tuple unpacking etc.).
    target = generator.target
    if not isinstance(target, ast.Name):
        raise unsupported(
            "Spread pattern requires a simple loop variable",
            "Use a simple variable: [action(x) for x in items]",
        )

    # The iterable must be expressible in the IR.
    iterable_ir = _expr_to_ir(generator.iter)
    if not iterable_ir:
        raise unsupported(
            "Could not convert collection expression in spread pattern",
            "Ensure the collection is a simple variable or expression",
        )

    # The element expression must be a call ...
    element = listcomp.elt
    if not isinstance(element, ast.Call):
        raise unsupported(
            "Spread pattern requires an action call in the list comprehension",
            "Use: [action(x=item) for item in items]",
        )

    # ... and specifically a call to a known action.
    action_ir = self._extract_action_call_from_call(element)
    if not action_ir:
        raise unsupported(
            "Spread pattern element must be an @action call",
            "Ensure the function is decorated with @action",
        )

    spread_expr = ir.SpreadExpr()
    spread_expr.collection.CopyFrom(iterable_ir)
    spread_expr.loop_var = target.id
    spread_expr.action.CopyFrom(action_ir)
    return spread_expr
2701
+
2702
+ def _convert_gather_arg_to_call(self, node: ast.expr) -> Optional[ir.Call]:
2703
+ """Convert a gather argument to an IR Call.
2704
+
2705
+ Handles both action calls and regular function calls.
2706
+ """
2707
+ if not isinstance(node, ast.Call):
2708
+ return None
2709
+
2710
+ # Try to extract as an action call first
2711
+ action_call = self._extract_action_call_from_call(node)
2712
+ if action_call:
2713
+ call = ir.Call()
2714
+ call.action.CopyFrom(action_call)
2715
+ return call
2716
+
2717
+ # Fall back to regular function call
2718
+ func_call = self._convert_to_function_call(node)
2719
+ if func_call:
2720
+ call = ir.Call()
2721
+ call.function.CopyFrom(func_call)
2722
+ return call
2723
+
2724
+ return None
2725
+
2726
def _convert_to_function_call(self, node: ast.Call) -> Optional[ir.FunctionCall]:
    """Build an IR FunctionCall from an AST Call.

    Returns None when the callee cannot be reduced to a name. Arguments
    whose conversion yields nothing are left out; keywords without a name
    (``**mapping``) are skipped.
    """
    callee = self._get_func_name(node.func)
    if not callee:
        return None

    call_ir = ir.FunctionCall(name=callee)

    # Positional arguments, in source order.
    for positional in node.args:
        converted = _expr_to_ir(positional)
        if converted:
            call_ir.args.append(converted)

    # Named keyword arguments.
    for keyword in node.keywords:
        if not keyword.arg:
            continue
        converted = _expr_to_ir(keyword.value)
        if converted:
            call_ir.kwargs.append(ir.Kwarg(name=keyword.arg, value=converted))

    return call_ir
2748
+
2749
+ def _get_func_name(self, node: ast.expr) -> Optional[str]:
2750
+ """Get function name from a func node."""
2751
+ if isinstance(node, ast.Name):
2752
+ return node.id
2753
+ elif isinstance(node, ast.Attribute):
2754
+ # Handle chained attributes like obj.method
2755
+ parts = []
2756
+ current = node
2757
+ while isinstance(current, ast.Attribute):
2758
+ parts.append(current.attr)
2759
+ current = current.value
2760
+ if isinstance(current, ast.Name):
2761
+ parts.append(current.id)
2762
+ name = ".".join(reversed(parts))
2763
+ if name.startswith("self."):
2764
+ return name[5:]
2765
+ return name
2766
+ return None
2767
+
2768
def _expr_to_ir_with_model_coercion(self, node: ast.expr) -> Optional[ir.Expr]:
    """Convert an expression to IR, rewriting model constructors as dicts.

    Used for action arguments: when ``node`` is a call that
    ``_is_model_constructor`` recognises (e.g. ``MyModel(field=value)``),
    it is converted through ``_convert_model_constructor_to_dict``.
    Every other expression goes through the standard ``_expr_to_ir``.
    """
    model_name = self._is_model_constructor(node) if isinstance(node, ast.Call) else None
    if model_name:
        return self._convert_model_constructor_to_dict(node, model_name)

    # Not a model constructor: standard conversion.
    return _expr_to_ir(node)
2786
+
2787
+ def _extract_action_call_from_awaitable(self, node: ast.expr) -> Optional[ir.ActionCall]:
2788
+ """Extract action call from an awaitable expression."""
2789
+ if isinstance(node, ast.Call):
2790
+ return self._extract_action_call_from_call(node)
2791
+ return None
2792
+
2793
def _extract_action_call_from_call(self, node: ast.Call) -> Optional[ir.ActionCall]:
    """Build an IR ActionCall from a Call node that targets a known action.

    Positional arguments are renamed to keyword arguments using the
    parameter order of the action's signature, so every argument in the IR
    is named. Argument values go through ``_expr_to_ir_with_model_coercion``,
    which rewrites model constructors as dict expressions.

    Returns None when the callee has no resolvable name or the name is not
    in ``self._action_defs``.
    """
    name = self._get_action_name(node.func)
    if not name or name not in self._action_defs:
        return None

    definition = self._action_defs[name]
    call_ir = ir.ActionCall(action_name=definition.action_name)

    # Record the module so the worker knows where to find the action.
    if definition.module_name:
        call_ir.module_name = definition.module_name

    # Pair positional arguments with parameter names in signature order.
    # zip() stops at the shorter sequence, so positional arguments beyond
    # the declared parameters are ignored.
    declared = list(definition.signature.parameters.keys())
    for param_name, positional in zip(declared, node.args):
        value = self._expr_to_ir_with_model_coercion(positional)
        if value:
            call_ir.kwargs.append(ir.Kwarg(name=param_name, value=value))

    # Explicit keyword arguments keep their own names; ``**mapping``
    # keywords (no name) are skipped.
    for keyword in node.keywords:
        if not keyword.arg:
            continue
        value = self._expr_to_ir_with_model_coercion(keyword.value)
        if value:
            call_ir.kwargs.append(ir.Kwarg(name=keyword.arg, value=value))

    return call_ir
2838
+
2839
+ def _get_action_name(self, func: ast.expr) -> Optional[str]:
2840
+ """Get the action name from a function expression."""
2841
+ if isinstance(func, ast.Name):
2842
+ return func.id
2843
+ elif isinstance(func, ast.Attribute):
2844
+ return func.attr
2845
+ return None
2846
+
2847
+ def _get_assign_target(self, targets: List[ast.expr]) -> Optional[str]:
2848
+ """Get the target variable name from assignment targets (single target only)."""
2849
+ if targets and isinstance(targets[0], ast.Name):
2850
+ return targets[0].id
2851
+ return None
2852
+
2853
+ def _get_assign_targets(self, targets: List[ast.expr]) -> List[str]:
2854
+ """Get all target variable names from assignment targets (including tuple unpacking)."""
2855
+ result: List[str] = []
2856
+ for t in targets:
2857
+ if isinstance(t, ast.Name):
2858
+ result.append(t.id)
2859
+ elif isinstance(t, ast.Subscript):
2860
+ formatted = _format_subscript_target(t)
2861
+ if formatted:
2862
+ result.append(formatted)
2863
+ elif isinstance(t, ast.Tuple):
2864
+ for elt in t.elts:
2865
+ if isinstance(elt, ast.Name):
2866
+ result.append(elt.id)
2867
+ return result
2868
+
2869
+
2870
def _make_span(node: ast.AST) -> ir.Span:
    """Build an IR Span from the position attributes of an AST node.

    Missing position attributes default to 0. End positions that are
    present but falsy (e.g. None) are also normalised to 0.
    """
    end_line = getattr(node, "end_lineno", 0)
    end_col = getattr(node, "end_col_offset", 0)
    return ir.Span(
        start_line=getattr(node, "lineno", 0),
        start_col=getattr(node, "col_offset", 0),
        end_line=end_line if end_line else 0,
        end_col=end_col if end_col else 0,
    )
2878
+
2879
+
2880
def _expr_to_ir(expr: ast.AST) -> Optional[ir.Expr]:
    """Convert Python AST expression to IR Expr.

    Supported node kinds: names, constants, binary/unary/boolean operators,
    comparisons (including chained ones), list/tuple/dict literals,
    subscripts, attribute access, calls, and awaited calls.

    Args:
        expr: The AST node to convert.

    Returns:
        The IR expression, or None when ``expr`` is not an expression node
        or a sub-expression could not be converted.

    Raises:
        UnsupportedPatternError: When the node (or one of its children) is
            an expression kind the IR cannot represent.
    """
    result = ir.Expr(span=_make_span(expr))

    if isinstance(expr, ast.Name):
        result.variable.CopyFrom(ir.Variable(name=expr.id))
        return result

    if isinstance(expr, ast.Constant):
        literal = _constant_to_literal(expr.value)
        if literal:
            result.literal.CopyFrom(literal)
            return result

    if isinstance(expr, ast.BinOp):
        left = _expr_to_ir(expr.left)
        right = _expr_to_ir(expr.right)
        op = _bin_op_to_ir(expr.op)
        if left and right and op:
            result.binary_op.CopyFrom(ir.BinaryOp(left=left, op=op, right=right))
            return result

    if isinstance(expr, ast.UnaryOp):
        operand = _expr_to_ir(expr.operand)
        op = _unary_op_to_ir(expr.op)
        if operand and op:
            result.unary_op.CopyFrom(ir.UnaryOp(op=op, operand=operand))
            return result

    if isinstance(expr, ast.Compare):
        left = _expr_to_ir(expr.left)
        if not left:
            return None
        # Convert every (operator, comparator) pair. A single unsupported
        # pair makes the whole comparison unsupported, so it is reported at
        # the end of this function instead of being silently truncated.
        pairs = []
        for cmp_node, comparator in zip(expr.ops, expr.comparators):
            op = _cmp_op_to_ir(cmp_node)
            right = _expr_to_ir(comparator)
            if not (op and right):
                pairs = []
                break
            pairs.append((op, right))
        if pairs:
            # ``a < b < c`` is lowered to ``(a < b) and (b < c)``.
            # NOTE: the shared operand ``b`` appears twice in the IR,
            # whereas Python evaluates it only once.
            combined = None
            lhs_operand = left
            for op, right in pairs:
                comparison = ir.BinaryOp(left=lhs_operand, op=op, right=right)
                if combined is None:
                    combined = comparison
                else:
                    lhs_expr = ir.Expr(span=_make_span(expr))
                    lhs_expr.binary_op.CopyFrom(combined)
                    rhs_expr = ir.Expr(span=_make_span(expr))
                    rhs_expr.binary_op.CopyFrom(comparison)
                    combined = ir.BinaryOp(
                        left=lhs_expr,
                        op=ir.BinaryOperator.BINARY_OP_AND,
                        right=rhs_expr,
                    )
                lhs_operand = right
            result.binary_op.CopyFrom(combined)
            return result

    if isinstance(expr, ast.BoolOp):
        values = [_expr_to_ir(v) for v in expr.values]
        if all(values):
            op = _bool_op_to_ir(expr.op)
            if op and len(values) >= 2:
                # Chain boolean ops: a and b and c -> (a and b) and c
                result_expr = values[0]
                for v in values[1:]:
                    # Keep the source span on every synthesized node.
                    new_result = ir.Expr(span=_make_span(expr))
                    new_result.binary_op.CopyFrom(ir.BinaryOp(left=result_expr, op=op, right=v))
                    result_expr = new_result
                return result_expr

    if isinstance(expr, ast.List):
        elements = [_expr_to_ir(e) for e in expr.elts]
        if all(elements):
            result.list.CopyFrom(ir.ListExpr(elements=elements))
            return result

    if isinstance(expr, ast.Dict):
        entries: List[ir.DictEntry] = []
        for k, v in zip(expr.keys, expr.values, strict=False):
            if k is None:
                # A None key marks ``{**other}``; dropping it would silently
                # change the contents of the resulting dict.
                raise UnsupportedPatternError(
                    "Dict unpacking (**mapping) inside a dict literal is not supported",
                    RECOMMENDATIONS["unsupported_expression"],
                    line=getattr(expr, "lineno", None),
                    col=getattr(expr, "col_offset", None),
                )
            key_expr = _expr_to_ir(k)
            value_expr = _expr_to_ir(v)
            if key_expr and value_expr:
                entries.append(ir.DictEntry(key=key_expr, value=value_expr))
        result.dict.CopyFrom(ir.DictExpr(entries=entries))
        return result

    if isinstance(expr, ast.Subscript):
        obj = _expr_to_ir(expr.value)
        index = _expr_to_ir(expr.slice) if isinstance(expr.slice, ast.AST) else None
        if obj and index:
            result.index.CopyFrom(ir.IndexAccess(object=obj, index=index))
            return result

    if isinstance(expr, ast.Attribute):
        obj = _expr_to_ir(expr.value)
        if obj:
            result.dot.CopyFrom(ir.DotAccess(object=obj, attribute=expr.attr))
            return result

    # ``await f(...)`` and ``f(...)`` both lower to a FunctionCall.
    call_node: Optional[ast.Call] = None
    if isinstance(expr, ast.Await) and isinstance(expr.value, ast.Call):
        call_node = expr.value
    elif isinstance(expr, ast.Call):
        call_node = expr
    if call_node is not None:
        func_call = _call_node_to_function_call(call_node)
        if func_call is not None:
            result.function_call.CopyFrom(func_call)
            return result

    if isinstance(expr, ast.Tuple):
        # Handle tuple as list for now
        elements = [_expr_to_ir(e) for e in expr.elts]
        if all(elements):
            result.list.CopyFrom(ir.ListExpr(elements=elements))
            return result

    # Check for unsupported expression types
    _check_unsupported_expression(expr)

    return None


def _call_node_to_function_call(call: ast.Call) -> Optional[ir.FunctionCall]:
    """Build an IR FunctionCall from a Call node, or None if the callee has no usable name."""
    func_name = _get_func_name(call.func)
    if not func_name:
        return None

    args = [_expr_to_ir(a) for a in call.args]
    kwargs: List[ir.Kwarg] = []
    for kw in call.keywords:
        # kw.arg is None for ``**mapping`` keywords; those are skipped.
        if kw.arg:
            kw_expr = _expr_to_ir(kw.value)
            if kw_expr:
                kwargs.append(ir.Kwarg(name=kw.arg, value=kw_expr))

    return ir.FunctionCall(
        name=func_name,
        args=[a for a in args if a],
        kwargs=kwargs,
    )
3015
+
3016
+
3017
+ def _check_unsupported_expression(expr: ast.AST) -> None:
3018
+ """Check for unsupported expression types and raise descriptive errors."""
3019
+ line = getattr(expr, "lineno", None)
3020
+ col = getattr(expr, "col_offset", None)
3021
+
3022
+ if isinstance(expr, ast.Constant):
3023
+ if _constant_to_literal(expr.value) is None:
3024
+ raise UnsupportedPatternError(
3025
+ f"Unsupported literal type '{type(expr.value).__name__}'",
3026
+ RECOMMENDATIONS["unsupported_literal"],
3027
+ line=line,
3028
+ col=col,
3029
+ )
3030
+
3031
+ if isinstance(expr, ast.JoinedStr):
3032
+ raise UnsupportedPatternError(
3033
+ "F-strings are not supported",
3034
+ RECOMMENDATIONS["fstring"],
3035
+ line=line,
3036
+ col=col,
3037
+ )
3038
+ elif isinstance(expr, ast.Lambda):
3039
+ raise UnsupportedPatternError(
3040
+ "Lambda expressions are not supported",
3041
+ RECOMMENDATIONS["lambda"],
3042
+ line=line,
3043
+ col=col,
3044
+ )
3045
+ elif isinstance(expr, ast.ListComp):
3046
+ raise UnsupportedPatternError(
3047
+ "List comprehensions are not supported in this context",
3048
+ RECOMMENDATIONS["list_comprehension"],
3049
+ line=line,
3050
+ col=col,
3051
+ )
3052
+ elif isinstance(expr, ast.DictComp):
3053
+ raise UnsupportedPatternError(
3054
+ "Dict comprehensions are not supported in this context",
3055
+ RECOMMENDATIONS["dict_comprehension"],
3056
+ line=line,
3057
+ col=col,
3058
+ )
3059
+ elif isinstance(expr, ast.SetComp):
3060
+ raise UnsupportedPatternError(
3061
+ "Set comprehensions are not supported",
3062
+ RECOMMENDATIONS["set_comprehension"],
3063
+ line=line,
3064
+ col=col,
3065
+ )
3066
+ elif isinstance(expr, ast.GeneratorExp):
3067
+ raise UnsupportedPatternError(
3068
+ "Generator expressions are not supported",
3069
+ RECOMMENDATIONS["generator"],
3070
+ line=line,
3071
+ col=col,
3072
+ )
3073
+ elif isinstance(expr, ast.NamedExpr):
3074
+ raise UnsupportedPatternError(
3075
+ "The walrus operator (:=) is not supported",
3076
+ RECOMMENDATIONS["walrus"],
3077
+ line=line,
3078
+ col=col,
3079
+ )
3080
+ elif isinstance(expr, ast.Yield) or isinstance(expr, ast.YieldFrom):
3081
+ raise UnsupportedPatternError(
3082
+ "Yield expressions are not supported",
3083
+ RECOMMENDATIONS["yield_statement"],
3084
+ line=line,
3085
+ col=col,
3086
+ )
3087
+ elif isinstance(expr, ast.expr):
3088
+ raise UnsupportedPatternError(
3089
+ f"Unsupported expression type '{type(expr).__name__}'",
3090
+ RECOMMENDATIONS["unsupported_expression"],
3091
+ line=line,
3092
+ col=col,
3093
+ )
3094
+
3095
+
3096
+ def _format_subscript_target(target: ast.Subscript) -> Optional[str]:
3097
+ """Convert a subscript target to a string representation for tracking."""
3098
+ if not isinstance(target.value, ast.Name):
3099
+ return None
3100
+
3101
+ base = target.value.id
3102
+ try:
3103
+ index_str = ast.unparse(target.slice)
3104
+ except Exception:
3105
+ return None
3106
+
3107
+ return f"{base}[{index_str}]"
3108
+
3109
+
3110
def _constant_to_literal(value: Any) -> Optional[ir.Literal]:
    """Convert a Python constant to an IR Literal.

    None, bool, int, float and str are supported; any other value type
    yields None. bool is tested before int because bool is an int subclass.
    """
    literal = ir.Literal()

    if value is None:
        field_name, payload = "is_none", True
    elif isinstance(value, bool):
        field_name, payload = "bool_value", value
    elif isinstance(value, int):
        field_name, payload = "int_value", value
    elif isinstance(value, float):
        field_name, payload = "float_value", value
    elif isinstance(value, str):
        field_name, payload = "string_value", value
    else:
        return None

    setattr(literal, field_name, payload)
    return literal
3126
+
3127
+
3128
def _bin_op_to_ir(op: ast.operator) -> Optional[ir.BinaryOperator]:
    """Map a Python arithmetic operator node to its IR BinaryOperator.

    Returns None for operator types that have no entry in the table.
    """
    operators = ir.BinaryOperator
    supported = {
        ast.Add: operators.BINARY_OP_ADD,
        ast.Sub: operators.BINARY_OP_SUB,
        ast.Mult: operators.BINARY_OP_MUL,
        ast.Div: operators.BINARY_OP_DIV,
        ast.FloorDiv: operators.BINARY_OP_FLOOR_DIV,
        ast.Mod: operators.BINARY_OP_MOD,
    }
    node_type = type(op)
    return supported[node_type] if node_type in supported else None
3139
+
3140
+
3141
def _unary_op_to_ir(op: ast.unaryop) -> Optional[ir.UnaryOperator]:
    """Map a Python unary operator node to its IR UnaryOperator.

    Only negation (``-x``) and logical ``not`` have entries; every other
    operator type gives None.
    """
    operators = ir.UnaryOperator
    supported = {
        ast.USub: operators.UNARY_OP_NEG,
        ast.Not: operators.UNARY_OP_NOT,
    }
    try:
        return supported[type(op)]
    except KeyError:
        return None
3148
+
3149
+
3150
def _cmp_op_to_ir(op: ast.cmpop) -> Optional[ir.BinaryOperator]:
    """Map a Python comparison operator node to its IR BinaryOperator.

    Returns None for comparison operators without an entry in the table.
    """
    operators = ir.BinaryOperator
    supported = {
        ast.Eq: operators.BINARY_OP_EQ,
        ast.NotEq: operators.BINARY_OP_NE,
        ast.Lt: operators.BINARY_OP_LT,
        ast.LtE: operators.BINARY_OP_LE,
        ast.Gt: operators.BINARY_OP_GT,
        ast.GtE: operators.BINARY_OP_GE,
        ast.In: operators.BINARY_OP_IN,
        ast.NotIn: operators.BINARY_OP_NOT_IN,
    }
    node_type = type(op)
    if node_type not in supported:
        return None
    return supported[node_type]
3163
+
3164
+
3165
def _bool_op_to_ir(op: ast.boolop) -> Optional[ir.BinaryOperator]:
    """Map a Python ``and`` / ``or`` node to its IR BinaryOperator, else None."""
    operators = ir.BinaryOperator
    supported = {
        ast.And: operators.BINARY_OP_AND,
        ast.Or: operators.BINARY_OP_OR,
    }
    try:
        return supported[type(op)]
    except KeyError:
        return None
3172
+
3173
+
3174
+ def _get_func_name(func: ast.expr) -> Optional[str]:
3175
+ """Get function name from a Call's func attribute."""
3176
+ if isinstance(func, ast.Name):
3177
+ return func.id
3178
+ elif isinstance(func, ast.Attribute):
3179
+ # For method calls like obj.method, return full dotted name
3180
+ parts = []
3181
+ current = func
3182
+ while isinstance(current, ast.Attribute):
3183
+ parts.append(current.attr)
3184
+ current = current.value
3185
+ if isinstance(current, ast.Name):
3186
+ parts.append(current.id)
3187
+ name = ".".join(reversed(parts))
3188
+ if name.startswith("self."):
3189
+ return name[5:]
3190
+ return name
3191
+ return None