rappel 0.5.5__py3-none-manylinux_2_39_aarch64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of rappel might be problematic. Click here for more details.

rappel/ir_builder.py ADDED
@@ -0,0 +1,3168 @@
1
+ """
2
+ IR Builder - Converts Python workflow AST to Rappel IR (ast.proto).
3
+
4
+ This module parses Python workflow classes and produces the IR representation
5
+ that can be sent to the Rust runtime for execution.
6
+
7
+ The IR builder performs deep transformations to convert Python patterns into
8
+ valid Rappel IR structures. Control flow bodies (try, for, if branches) are
9
+ represented as full blocks, so they can contain arbitrary statements and
10
+ `return` behaves consistently with Python (returns from the entry function end
11
+ the workflow).
12
+
13
+ Validation:
14
+ The IR builder proactively detects unsupported Python patterns and raises
15
+ UnsupportedPatternError with clear recommendations for how to rewrite the code.
16
+ """
17
+
18
+ from __future__ import annotations
19
+
20
+ import ast
21
+ import copy
22
+ import inspect
23
+ import textwrap
24
+ from dataclasses import dataclass
25
+ from typing import TYPE_CHECKING, Any, Dict, List, Optional, Set, Union
26
+
27
+ from proto import ast_pb2 as ir
28
+
29
+
30
class UnsupportedPatternError(Exception):
    """Raised when the IR builder encounters an unsupported Python pattern.

    The rendered exception text is ``message``, followed by an optional
    ``(filename, line N, col M)`` suffix and a ``Recommendation:`` paragraph
    describing how to rewrite the code to use supported patterns.

    Attributes:
        message: Short description of the unsupported pattern.
        recommendation: Guidance on how to rewrite the offending code.
        line: Line of the offending node, if known.
        col: Column of the offending node, if known.
        filename: File containing the offending node, if known.
    """

    def __init__(
        self,
        message: str,
        recommendation: str,
        line: Optional[int] = None,
        col: Optional[int] = None,
        filename: Optional[str] = None,
    ):
        self.message = message
        self.recommendation = recommendation
        self.line = line
        self.col = col
        self.filename = filename

        # Only the pieces of location information that are known are rendered.
        location_parts: List[str] = []
        if filename:
            location_parts.append(filename)
        if line:
            location_parts.append(f"line {line}")
        if col is not None:
            # Column 0 is meaningful, hence the explicit None check.
            location_parts.append(f"col {col}")
        location = f" ({', '.join(location_parts)})" if location_parts else ""
        full_message = f"{message}{location}\n\nRecommendation: {recommendation}"
        super().__init__(full_message)

    def __reduce__(self):
        """Rebuild from the original constructor arguments.

        ``Exception`` pickles and copies itself by calling ``cls(*self.args)``.
        ``args`` holds only the formatted message, while ``__init__`` requires
        both ``message`` and ``recommendation``, so the default behaviour
        raises ``TypeError`` when the error is pickled or copied.
        """
        constructor_args = (
            self.message,
            self.recommendation,
            self.line,
            self.col,
            self.filename,
        )
        return (type(self), constructor_args, self.__dict__)
61
+
62
+
63
# Recommendations for common unsupported patterns.
# Maps a short pattern key (e.g. "while_loop") to the multi-line rewrite advice
# that is passed as the ``recommendation`` argument of UnsupportedPatternError.
RECOMMENDATIONS = {
    "constructor_return": (
        "Returning a class constructor (like MyModel(...)) directly is not supported.\n"
        "The workflow IR cannot serialize arbitrary object instantiation.\n\n"
        "Use an @action to create the object:\n\n"
        " @action\n"
        " async def build_result(items: list, count: int) -> MyResult:\n"
        " return MyResult(items=items, count=count)\n\n"
        " # In workflow:\n"
        " return await build_result(items, count)"
    ),
    "constructor_assignment": (
        "Assigning a class constructor result (like x = MyClass(...)) is not supported.\n"
        "The workflow IR cannot serialize arbitrary object instantiation.\n\n"
        "Use an @action to create the object:\n\n"
        " @action\n"
        " async def create_config(value: int) -> Config:\n"
        " return Config(value=value)\n\n"
        " # In workflow:\n"
        " config = await create_config(value)"
    ),
    "non_action_call": (
        "Calling a function that is not decorated with @action is not supported.\n"
        "Only @action decorated functions can be awaited in workflow code.\n\n"
        "Add the @action decorator to your function:\n\n"
        " @action\n"
        " async def my_function(x: int) -> int:\n"
        " return x * 2"
    ),
    "sync_function_call": (
        "Calling a synchronous function directly in workflow code is not supported.\n"
        "All computation must happen inside @action decorated async functions.\n\n"
        "Wrap your logic in an @action:\n\n"
        " @action\n"
        " async def compute(x: int) -> int:\n"
        " return some_sync_function(x)"
    ),
    "method_call_non_self": (
        "Calling methods on objects other than 'self' is not supported in workflow code.\n"
        "Use an @action to perform method calls:\n\n"
        " @action\n"
        " async def call_method(obj: MyClass) -> Result:\n"
        " return obj.some_method()"
    ),
    "builtin_call": (
        "Calling built-in functions like len(), str(), int() directly is not supported.\n"
        "Use an @action to perform these operations:\n\n"
        " @action\n"
        " async def get_length(items: list) -> int:\n"
        " return len(items)"
    ),
    "fstring": (
        "F-strings are not supported in workflow code because they require "
        "runtime string interpolation.\n"
        "Use an @action to perform string formatting:\n\n"
        " @action\n"
        " async def format_message(value: int) -> str:\n"
        " return f'Result: {value}'"
    ),
    "delete": (
        "The 'del' statement is not supported in workflow code.\n"
        "Use an @action to perform mutations:\n\n"
        " @action\n"
        " async def remove_key(data: dict, key: str) -> dict:\n"
        " del data[key]\n"
        " return data"
    ),
    "while_loop": (
        "While loops are not supported in workflow code because they can run "
        "indefinitely.\n"
        "Use a for loop with a fixed range, or restructure as recursive workflow calls."
    ),
    "with_statement": (
        "Context managers (with statements) are not supported in workflow code.\n"
        "Use an @action to handle resource management:\n\n"
        " @action\n"
        " async def read_file(path: str) -> str:\n"
        " with open(path) as f:\n"
        " return f.read()"
    ),
    "raise_statement": (
        "The 'raise' statement is not supported directly in workflow code.\n"
        "Use an @action that raises exceptions, or return error values."
    ),
    "assert_statement": (
        "Assert statements are not supported in workflow code.\n"
        "Use an @action for validation, or use if statements with explicit error handling."
    ),
    "lambda": (
        "Lambda expressions are not supported in workflow code.\n"
        "Use an @action to define the function logic."
    ),
    "list_comprehension": (
        "List comprehensions are only supported when assigned directly to a variable "
        "or inside asyncio.gather(*[...]).\n"
        "For other cases, use a for loop or an @action."
    ),
    "dict_comprehension": (
        "Dict comprehensions are only supported when assigned directly to a variable.\n"
        "For other cases, use a for loop or an @action:\n\n"
        " @action\n"
        " async def build_dict(items: list) -> dict:\n"
        " return {k: v for k, v in items}"
    ),
    "set_comprehension": (
        "Set comprehensions are not supported in workflow code.\nUse an @action to build sets."
    ),
    "generator": (
        "Generator expressions are not supported in workflow code.\n"
        "Use a list or an @action instead."
    ),
    "walrus": (
        "The walrus operator (:=) is not supported in workflow code.\n"
        "Use separate assignment statements instead."
    ),
    "match": (
        "Match statements are not supported in workflow code.\nUse if/elif/else chains instead."
    ),
    "gather_variable_spread": (
        "Spreading a variable in asyncio.gather() is not supported because it requires "
        "data flow analysis to determine the contents.\n"
        "Use a list comprehension directly in gather:\n\n"
        " # Instead of:\n"
        " tasks = []\n"
        " for i in range(count):\n"
        " tasks.append(process(value=i))\n"
        " results = await asyncio.gather(*tasks)\n\n"
        " # Use:\n"
        " results = await asyncio.gather(*[process(value=i) for i in range(count)])"
    ),
    "for_loop_append_pattern": (
        "Building a task list in a for loop then spreading in asyncio.gather() is not "
        "supported.\n"
        "Use a list comprehension directly in gather:\n\n"
        " # Instead of:\n"
        " tasks = []\n"
        " for i in range(count):\n"
        " tasks.append(process(value=i))\n"
        " results = await asyncio.gather(*tasks)\n\n"
        " # Use:\n"
        " results = await asyncio.gather(*[process(value=i) for i in range(count)])"
    ),
    "global_statement": (
        "Global statements are not supported in workflow code.\n"
        "Use workflow state or pass values explicitly."
    ),
    "nonlocal_statement": (
        "Nonlocal statements are not supported in workflow code.\n"
        "Use explicit parameter passing instead."
    ),
    "import_statement": (
        "Import statements inside workflow run() are not supported.\n"
        "Place imports at the module level."
    ),
    "class_def": (
        "Class definitions inside workflow run() are not supported.\n"
        "Define classes at the module level."
    ),
    "function_def": (
        "Nested function definitions inside workflow run() are not supported.\n"
        "Define functions at the module level or use @action."
    ),
    "yield_statement": (
        "Yield statements are not supported in workflow code.\n"
        "Workflows must return a complete result, not generate values incrementally."
    ),
    "unsupported_expression": (
        "This expression type is not supported in workflow code.\n"
        "Move the logic into an @action or rewrite using supported expressions."
    ),
    "unsupported_literal": (
        "This literal type is not supported in workflow code.\n"
        "Convert the value to a supported literal type inside an @action."
    ),
}
239
+
240
+
241
+ if TYPE_CHECKING:
242
+ from .workflow import Workflow
243
+
244
+
245
@dataclass
class ActionDefinition:
    """Definition of an action function."""

    # Value of the callable's ``__rappel_action_name__`` attribute.
    action_name: str
    # The callable's ``__rappel_action_module__`` when set, otherwise the name
    # of the module the callable was discovered in.
    module_name: Optional[str]
    # Signature of the action callable, as returned by ``inspect.signature``.
    signature: inspect.Signature
252
+
253
+
254
@dataclass
class ModuleContext:
    """Cached IRBuilder context derived from a module."""

    # Action callables found on the module, keyed by attribute name.
    action_defs: Dict[str, ActionDefinition]
    # ``from X import Y`` bindings parsed from the module source, keyed by local name.
    imported_names: Dict[str, "ImportedName"]
    # Names of coroutine functions whose ``__module__`` is this module.
    module_functions: Set[str]
    # Simple Pydantic models and dataclasses visible on the module, keyed by attribute name.
    model_defs: Dict[str, "ModelDefinition"]
262
+
263
+
264
@dataclass
class TransformContext:
    """Mutable state shared across IR transformations.

    Tracks how many implicit function names have been handed out and collects
    the implicit functions generated so far.
    """

    # Number of implicit function names issued so far.
    implicit_fn_counter: int = 0
    # Implicit functions generated during transformation. ``None`` is replaced
    # by a fresh list in ``__post_init__`` so instances never share a list.
    implicit_functions: List[ir.FunctionDef] = None  # type: ignore

    def __post_init__(self) -> None:
        """Give the instance its own list when none was supplied."""
        if self.implicit_functions is None:
            self.implicit_functions = []

    def next_implicit_fn_name(self, prefix: str = "implicit") -> str:
        """Return the next unique name, formatted as ``__<prefix>_<n>__``."""
        sequence_number = self.implicit_fn_counter + 1
        self.implicit_fn_counter = sequence_number
        return "__{}_{}__".format(prefix, sequence_number)
281
+
282
+
283
def build_workflow_ir(workflow_cls: type["Workflow"]) -> ir.Program:
    """Build an IR Program from a workflow class.

    The entry point is ``workflow_cls.__workflow_run_impl__`` when present,
    otherwise the class's own ``run``; it is emitted under the name ``main``.
    Helper methods reachable through ``self.method()`` calls are emitted under
    their own names. Implicit functions generated during transformation are
    placed ahead of them in the program.

    Helper methods are processed in sorted order so that the emitted function
    order and the numbering of implicit functions are the same on every run,
    regardless of string hash randomization.

    Args:
        workflow_cls: The workflow class to convert.

    Returns:
        An IR Program proto message.

    Raises:
        ValueError: If no run implementation or owning module can be located.
        UnsupportedPatternError: If the workflow source uses an unsupported
            pattern; the error is re-raised with file-relative location info.
    """
    original_run = getattr(workflow_cls, "__workflow_run_impl__", None)
    if original_run is None:
        original_run = workflow_cls.__dict__.get("run")
    if original_run is None:
        raise ValueError(f"workflow {workflow_cls!r} missing run implementation")

    module = inspect.getmodule(original_run)
    if module is None:
        raise ValueError(f"unable to locate module for workflow {workflow_cls!r}")

    # Per-module discovery results, computed once per module name.
    module_contexts: Dict[str, ModuleContext] = {}

    def get_module_context(target_module: Any) -> ModuleContext:
        """Return (building on first use) the cached context for a module."""
        module_name = target_module.__name__
        if module_name not in module_contexts:
            module_contexts[module_name] = ModuleContext(
                action_defs=_discover_action_names(target_module),
                imported_names=_discover_module_imports(target_module),
                module_functions=_discover_module_functions(target_module),
                model_defs=_discover_model_definitions(target_module),
            )
        return module_contexts[module_name]

    # Build the IR with transformation context
    ctx = TransformContext()
    program = ir.Program()
    function_defs: Dict[str, ir.FunctionDef] = {}

    def parse_function(fn: Any) -> tuple[ast.AST, Optional[str], int]:
        """Parse the dedented source of ``fn``; return (tree, filename, first line)."""
        source_lines, start_line = inspect.getsourcelines(fn)
        function_source = textwrap.dedent("".join(source_lines))
        filename = inspect.getsourcefile(fn)
        if filename is None:
            filename = inspect.getfile(fn)
        return ast.parse(function_source, filename=filename or "<unknown>"), filename, start_line

    def _with_source_location(
        err: UnsupportedPatternError,
        filename: Optional[str],
        start_line: int,
    ) -> UnsupportedPatternError:
        """Return a copy of ``err`` with its location rebased onto the source file."""
        line = err.line
        col = err.col
        if line is not None:
            # err.line is relative to the parsed function snippet.
            line = start_line + line - 1
        if col is not None:
            # NOTE(review): the column is relative to the dedented snippet, so
            # it does not include the method's original indentation - confirm
            # whether that is intended.
            col = col + 1
        return UnsupportedPatternError(
            err.message,
            err.recommendation,
            line=line,
            col=col,
            filename=filename,
        )

    def add_function_def(
        fn: Any,
        fn_tree: ast.AST,
        filename: Optional[str],
        start_line: int,
        override_name: Optional[str] = None,
    ) -> None:
        """Run an IRBuilder over ``fn_tree`` and record the resulting function."""
        fn_module = inspect.getmodule(fn)
        if fn_module is None:
            raise ValueError(f"unable to locate module for function {fn!r}")

        ctx_data = get_module_context(fn_module)
        builder = IRBuilder(
            ctx_data.action_defs,
            ctx,
            ctx_data.imported_names,
            ctx_data.module_functions,
            ctx_data.model_defs,
        )
        try:
            builder.visit(fn_tree)
        except UnsupportedPatternError as err:
            raise _with_source_location(err, filename, start_line) from err
        if builder.function_def:
            if override_name:
                builder.function_def.name = override_name
            function_defs[builder.function_def.name] = builder.function_def

    # Build the entry function from run() and map it to main in the IR.
    run_tree, run_filename, run_start_line = parse_function(original_run)
    add_function_def(original_run, run_tree, run_filename, run_start_line, "main")

    # Include helper methods reachable via self.method() calls.
    # The names come from a set; sorting (reversed, because pop() takes from
    # the end) makes the traversal order deterministic.
    pending = sorted(_collect_self_method_calls(run_tree), reverse=True)
    visited: Set[str] = set()
    skip_methods = {"run_action"}

    while pending:
        method_name = pending.pop()
        if method_name in visited or method_name == "run" or method_name in skip_methods:
            continue
        visited.add(method_name)

        method = _find_workflow_method(workflow_cls, method_name)
        if method is None:
            continue

        method_tree, method_filename, method_start_line = parse_function(method)
        add_function_def(method, method_tree, method_filename, method_start_line)

        pending.extend(sorted(_collect_self_method_calls(method_tree), reverse=True))

    # Add implicit functions first (they may be called by the main function)
    for implicit_fn in ctx.implicit_functions:
        program.functions.append(implicit_fn)

    # Add all function definitions (run + reachable helper methods)
    for fn_def in function_defs.values():
        program.functions.append(fn_def)

    return program
408
+
409
+
410
+ def _collect_self_method_calls(tree: ast.AST) -> Set[str]:
411
+ """Collect self.method(...) call names from a parsed function AST."""
412
+ calls: Set[str] = set()
413
+ for node in ast.walk(tree):
414
+ if not isinstance(node, ast.Call):
415
+ continue
416
+ func = node.func
417
+ if isinstance(func, ast.Attribute) and isinstance(func.value, ast.Name):
418
+ if func.value.id == "self":
419
+ calls.add(func.attr)
420
+ return calls
421
+
422
+
423
+ def _find_workflow_method(workflow_cls: type["Workflow"], name: str) -> Optional[Any]:
424
+ """Find a workflow method by name across the class MRO."""
425
+ for base in workflow_cls.__mro__:
426
+ if name not in base.__dict__:
427
+ continue
428
+ value = base.__dict__[name]
429
+ if isinstance(value, staticmethod) or isinstance(value, classmethod):
430
+ return value.__func__
431
+ if inspect.isfunction(value):
432
+ return value
433
+ return None
434
+
435
+
436
+ def _discover_action_names(module: Any) -> Dict[str, ActionDefinition]:
437
+ """Discover all @action decorated functions in a module."""
438
+ names: Dict[str, ActionDefinition] = {}
439
+ for attr_name in dir(module):
440
+ attr = getattr(module, attr_name)
441
+ action_name = getattr(attr, "__rappel_action_name__", None)
442
+ action_module = getattr(attr, "__rappel_action_module__", None)
443
+ if callable(attr) and action_name:
444
+ signature = inspect.signature(attr)
445
+ names[attr_name] = ActionDefinition(
446
+ action_name=action_name,
447
+ module_name=action_module or module.__name__,
448
+ signature=signature,
449
+ )
450
+ return names
451
+
452
+
453
+ def _discover_module_functions(module: Any) -> Set[str]:
454
+ """Discover all async function names defined in a module.
455
+
456
+ This is used to detect when users await functions in the same module
457
+ that are NOT decorated with @action.
458
+ """
459
+ function_names: Set[str] = set()
460
+ for attr_name in dir(module):
461
+ try:
462
+ attr = getattr(module, attr_name)
463
+ except AttributeError:
464
+ continue
465
+
466
+ # Only include functions defined in THIS module (not imported)
467
+ if not callable(attr):
468
+ continue
469
+ if not inspect.iscoroutinefunction(attr):
470
+ continue
471
+
472
+ # Check if the function is defined in this module
473
+ func_module = getattr(attr, "__module__", None)
474
+ if func_module == module.__name__:
475
+ function_names.add(attr_name)
476
+
477
+ return function_names
478
+
479
+
480
@dataclass
class ImportedName:
    """Tracks an imported name and its source module.

    For ``from asyncio import sleep as s`` the fields are
    ``local_name="s"``, ``module="asyncio"``, ``original_name="sleep"``.
    """

    local_name: str  # Name used in code (e.g., "sleep")
    module: str  # Source module (e.g., "asyncio")
    original_name: str  # Original name in source module (e.g., "sleep")
487
+
488
+
489
@dataclass
class ModelFieldDefinition:
    """Definition of a field in a Pydantic model or dataclass."""

    # Field name as declared on the model.
    name: str
    # True only when the field declares a plain default value; fields that
    # only have a default factory are recorded as having no default.
    has_default: bool
    default_value: Any = None  # Only set if has_default is True
496
+
497
+
498
@dataclass
class ModelDefinition:
    """Definition of a Pydantic model or dataclass that can be used in workflows.

    These are data classes that can be instantiated in workflow code and will
    be converted to dictionary expressions in the IR.
    """

    # Name under which the class is visible in the scanned module.
    class_name: str
    # Field definitions keyed by field name.
    fields: Dict[str, ModelFieldDefinition]
    is_pydantic: bool  # True for Pydantic models, False for dataclasses
509
+
510
+
511
+ def _is_simple_pydantic_model(cls: type) -> bool:
512
+ """Check if a class is a simple Pydantic model without custom validators.
513
+
514
+ A simple Pydantic model:
515
+ - Inherits from pydantic.BaseModel
516
+ - Has no field_validator or model_validator decorators
517
+ - Has no custom __init__ method
518
+
519
+ Returns False if pydantic is not installed or cls is not a Pydantic model.
520
+ """
521
+ try:
522
+ from pydantic import BaseModel
523
+ except ImportError:
524
+ return False
525
+
526
+ if not isinstance(cls, type) or not issubclass(cls, BaseModel):
527
+ return False
528
+
529
+ # Check for validators - Pydantic v2 uses __pydantic_decorators__
530
+ decorators = getattr(cls, "__pydantic_decorators__", None)
531
+ if decorators is not None:
532
+ # Check for field validators
533
+ if hasattr(decorators, "field_validators") and decorators.field_validators:
534
+ return False
535
+ # Check for model validators
536
+ if hasattr(decorators, "model_validators") and decorators.model_validators:
537
+ return False
538
+
539
+ # Check for custom __init__ (not the one from BaseModel)
540
+ if "__init__" in cls.__dict__:
541
+ return False
542
+
543
+ return True
544
+
545
+
546
+ def _is_simple_dataclass(cls: type) -> bool:
547
+ """Check if a class is a simple dataclass without custom logic.
548
+
549
+ A simple dataclass:
550
+ - Is decorated with @dataclass
551
+ - Has no custom __init__ method (uses the generated one)
552
+ - Has no __post_init__ method
553
+
554
+ Returns False if cls is not a dataclass.
555
+ """
556
+ import dataclasses
557
+
558
+ if not dataclasses.is_dataclass(cls):
559
+ return False
560
+
561
+ # Check for __post_init__ which could have custom logic
562
+ if hasattr(cls, "__post_init__") and "__post_init__" in cls.__dict__:
563
+ return False
564
+
565
+ # Dataclasses generate __init__, so check if there's a custom one
566
+ # that overrides it (unlikely but possible)
567
+ # The dataclass decorator sets __init__ so we can't easily detect override
568
+ # We'll trust that dataclasses without __post_init__ are simple
569
+
570
+ return True
571
+
572
+
573
+ def _get_pydantic_model_fields(cls: type) -> Dict[str, ModelFieldDefinition]:
574
+ """Extract field definitions from a Pydantic model."""
575
+ try:
576
+ from pydantic import BaseModel
577
+ from pydantic.fields import FieldInfo
578
+ except ImportError:
579
+ return {}
580
+
581
+ if not issubclass(cls, BaseModel):
582
+ return {}
583
+
584
+ fields: Dict[str, ModelFieldDefinition] = {}
585
+
586
+ # Pydantic v2 uses model_fields
587
+ model_fields = getattr(cls, "model_fields", {})
588
+ for field_name, field_info in model_fields.items():
589
+ has_default = False
590
+ default_value = None
591
+
592
+ if isinstance(field_info, FieldInfo):
593
+ # Check if field has a default value
594
+ # PydanticUndefined means no default
595
+ from pydantic_core import PydanticUndefined
596
+
597
+ if field_info.default is not PydanticUndefined:
598
+ has_default = True
599
+ default_value = field_info.default
600
+ elif field_info.default_factory is not None:
601
+ # We can't serialize factory functions, so treat as no default
602
+ has_default = False
603
+
604
+ fields[field_name] = ModelFieldDefinition(
605
+ name=field_name,
606
+ has_default=has_default,
607
+ default_value=default_value,
608
+ )
609
+
610
+ return fields
611
+
612
+
613
+ def _get_dataclass_fields(cls: type) -> Dict[str, ModelFieldDefinition]:
614
+ """Extract field definitions from a dataclass."""
615
+ import dataclasses
616
+
617
+ if not dataclasses.is_dataclass(cls):
618
+ return {}
619
+
620
+ fields: Dict[str, ModelFieldDefinition] = {}
621
+ for field in dataclasses.fields(cls):
622
+ has_default = False
623
+ default_value = None
624
+
625
+ if field.default is not dataclasses.MISSING:
626
+ has_default = True
627
+ default_value = field.default
628
+ elif field.default_factory is not dataclasses.MISSING:
629
+ # We can't serialize factory functions, so treat as no default
630
+ has_default = False
631
+
632
+ fields[field.name] = ModelFieldDefinition(
633
+ name=field.name,
634
+ has_default=has_default,
635
+ default_value=default_value,
636
+ )
637
+
638
+ return fields
639
+
640
+
641
+ def _discover_model_definitions(module: Any) -> Dict[str, ModelDefinition]:
642
+ """Discover all Pydantic models and dataclasses that can be used in workflows.
643
+
644
+ Only discovers "simple" models without custom validators or __post_init__.
645
+ """
646
+ models: Dict[str, ModelDefinition] = {}
647
+
648
+ for attr_name in dir(module):
649
+ try:
650
+ attr = getattr(module, attr_name)
651
+ except AttributeError:
652
+ continue
653
+
654
+ if not isinstance(attr, type):
655
+ continue
656
+
657
+ # Check if this class is defined in this module or imported
658
+ # We want to include both, as models might be imported
659
+ if _is_simple_pydantic_model(attr):
660
+ fields = _get_pydantic_model_fields(attr)
661
+ models[attr_name] = ModelDefinition(
662
+ class_name=attr_name,
663
+ fields=fields,
664
+ is_pydantic=True,
665
+ )
666
+ elif _is_simple_dataclass(attr):
667
+ fields = _get_dataclass_fields(attr)
668
+ models[attr_name] = ModelDefinition(
669
+ class_name=attr_name,
670
+ fields=fields,
671
+ is_pydantic=False,
672
+ )
673
+
674
+ return models
675
+
676
+
677
+ def _discover_module_imports(module: Any) -> Dict[str, ImportedName]:
678
+ """Discover imports in a module by parsing its source.
679
+
680
+ Tracks imports like:
681
+ - from asyncio import sleep -> {"sleep": ImportedName("sleep", "asyncio", "sleep")}
682
+ - from asyncio import sleep as s -> {"s": ImportedName("s", "asyncio", "sleep")}
683
+ """
684
+ imported: Dict[str, ImportedName] = {}
685
+
686
+ try:
687
+ source = inspect.getsource(module)
688
+ tree = ast.parse(source)
689
+ except (OSError, TypeError):
690
+ # Can't get source (e.g., built-in module)
691
+ return imported
692
+
693
+ for node in ast.walk(tree):
694
+ if isinstance(node, ast.ImportFrom) and node.module:
695
+ for alias in node.names:
696
+ local_name = alias.asname if alias.asname else alias.name
697
+ imported[local_name] = ImportedName(
698
+ local_name=local_name,
699
+ module=node.module,
700
+ original_name=alias.name,
701
+ )
702
+
703
+ return imported
704
+
705
+
706
+ class IRBuilder(ast.NodeVisitor):
707
+ """Builds IR from Python AST with deep transformations."""
708
+
709
+ def __init__(
710
+ self,
711
+ action_defs: Dict[str, ActionDefinition],
712
+ ctx: TransformContext,
713
+ imported_names: Optional[Dict[str, ImportedName]] = None,
714
+ module_functions: Optional[Set[str]] = None,
715
+ model_defs: Optional[Dict[str, ModelDefinition]] = None,
716
+ ):
717
+ self._action_defs = action_defs
718
+ self._ctx = ctx
719
+ self._imported_names = imported_names or {}
720
+ self._module_functions = module_functions or set()
721
+ self._model_defs = model_defs or {}
722
+ self.function_def: Optional[ir.FunctionDef] = None
723
+ self._statements: List[ir.Statement] = []
724
+
725
def visit_FunctionDef(self, node: ast.FunctionDef) -> Any:
    """Build ``self.function_def`` from a function definition node.

    The declared inputs are the positional parameters after the first one,
    which is taken to be ``self``. The function body is converted statement
    by statement and attached as the definition's block.
    """
    # Drop the leading parameter (expected to be 'self').
    inputs: List[str] = [param.arg for param in node.args.args[1:]]

    self.function_def = ir.FunctionDef(
        name=node.name,
        io=ir.IoDecl(inputs=inputs, outputs=[]),
        span=_make_span(node),
    )

    # One Python statement may expand into several IR statements.
    self._statements = []
    for body_stmt in node.body:
        self._statements.extend(self._visit_statement(body_stmt))

    self.function_def.body.CopyFrom(ir.Block(statements=self._statements))
747
+
748
def visit_AsyncFunctionDef(self, node: ast.AsyncFunctionDef) -> Any:
    """Build ``self.function_def`` from an async function definition node.

    Async definitions are converted exactly like synchronous ones: the first
    positional parameter (taken to be ``self``) is dropped from the inputs and
    the body is translated statement by statement.
    """
    positional = node.args.args
    inputs: List[str] = [param.arg for param in positional[1:]]

    io_decl = ir.IoDecl(inputs=inputs, outputs=[])
    self.function_def = ir.FunctionDef(
        name=node.name,
        io=io_decl,
        span=_make_span(node),
    )

    self._statements = []
    for body_stmt in node.body:
        converted = self._visit_statement(body_stmt)
        self._statements.extend(converted)

    body_block = ir.Block(statements=self._statements)
    self.function_def.body.CopyFrom(body_block)
767
+
768
+ def _visit_statement(self, node: ast.stmt) -> List[ir.Statement]:
769
+ """Convert a Python statement to IR Statement(s).
770
+
771
+ Returns a list because some transformations (like try block hoisting)
772
+ may expand a single Python statement into multiple IR statements.
773
+
774
+ Raises UnsupportedPatternError for unsupported statement types.
775
+ """
776
+ if isinstance(node, ast.Assign):
777
+ dict_expanded = self._expand_dict_comprehension_assignment(node)
778
+ if dict_expanded is not None:
779
+ return dict_expanded
780
+ expanded = self._expand_list_comprehension_assignment(node)
781
+ if expanded is not None:
782
+ return expanded
783
+ result = self._visit_assign(node)
784
+ return [result] if result else []
785
+ elif isinstance(node, ast.Expr):
786
+ result = self._visit_expr_stmt(node)
787
+ return [result] if result else []
788
+ elif isinstance(node, ast.For):
789
+ return self._visit_for(node)
790
+ elif isinstance(node, ast.If):
791
+ return self._visit_if(node)
792
+ elif isinstance(node, ast.Try):
793
+ return self._visit_try(node)
794
+ elif isinstance(node, ast.Return):
795
+ return self._visit_return(node)
796
+ elif isinstance(node, ast.AugAssign):
797
+ return self._visit_aug_assign(node)
798
+ elif isinstance(node, ast.Pass):
799
+ # Pass statements are fine, they just don't produce IR
800
+ return []
801
+
802
+ # Check for unsupported statement types
803
+ self._check_unsupported_statement(node)
804
+
805
+ return []
806
+
807
+ def _check_unsupported_statement(self, node: ast.stmt) -> None:
808
+ """Check for unsupported statement types and raise descriptive errors."""
809
+ line = getattr(node, "lineno", None)
810
+ col = getattr(node, "col_offset", None)
811
+
812
+ if isinstance(node, ast.While):
813
+ raise UnsupportedPatternError(
814
+ "While loops are not supported",
815
+ RECOMMENDATIONS["while_loop"],
816
+ line=line,
817
+ col=col,
818
+ )
819
+ elif isinstance(node, (ast.With, ast.AsyncWith)):
820
+ raise UnsupportedPatternError(
821
+ "Context managers (with statements) are not supported",
822
+ RECOMMENDATIONS["with_statement"],
823
+ line=line,
824
+ col=col,
825
+ )
826
+ elif isinstance(node, ast.Raise):
827
+ raise UnsupportedPatternError(
828
+ "The 'raise' statement is not supported",
829
+ RECOMMENDATIONS["raise_statement"],
830
+ line=line,
831
+ col=col,
832
+ )
833
+ elif isinstance(node, ast.Assert):
834
+ raise UnsupportedPatternError(
835
+ "Assert statements are not supported",
836
+ RECOMMENDATIONS["assert_statement"],
837
+ line=line,
838
+ col=col,
839
+ )
840
+ elif isinstance(node, ast.Delete):
841
+ raise UnsupportedPatternError(
842
+ "The 'del' statement is not supported",
843
+ RECOMMENDATIONS["delete"],
844
+ line=line,
845
+ col=col,
846
+ )
847
+ elif isinstance(node, ast.Global):
848
+ raise UnsupportedPatternError(
849
+ "Global statements are not supported",
850
+ RECOMMENDATIONS["global_statement"],
851
+ line=line,
852
+ col=col,
853
+ )
854
+ elif isinstance(node, ast.Nonlocal):
855
+ raise UnsupportedPatternError(
856
+ "Nonlocal statements are not supported",
857
+ RECOMMENDATIONS["nonlocal_statement"],
858
+ line=line,
859
+ col=col,
860
+ )
861
+ elif isinstance(node, (ast.Import, ast.ImportFrom)):
862
+ raise UnsupportedPatternError(
863
+ "Import statements inside workflow run() are not supported",
864
+ RECOMMENDATIONS["import_statement"],
865
+ line=line,
866
+ col=col,
867
+ )
868
+ elif isinstance(node, ast.ClassDef):
869
+ raise UnsupportedPatternError(
870
+ "Class definitions inside workflow run() are not supported",
871
+ RECOMMENDATIONS["class_def"],
872
+ line=line,
873
+ col=col,
874
+ )
875
+ elif isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef)):
876
+ raise UnsupportedPatternError(
877
+ "Nested function definitions are not supported",
878
+ RECOMMENDATIONS["function_def"],
879
+ line=line,
880
+ col=col,
881
+ )
882
+ elif hasattr(ast, "Match") and isinstance(node, ast.Match):
883
+ raise UnsupportedPatternError(
884
+ "Match statements are not supported",
885
+ RECOMMENDATIONS["match"],
886
+ line=line,
887
+ col=col,
888
+ )
889
+
890
def _expand_list_comprehension_assignment(
    self, node: ast.Assign
) -> Optional[List[ir.Statement]]:
    """Expand a list comprehension assignment into loop-based statements.

    Example:
        active_users = [user for user in users if user.active]

    Becomes:
        active_users = []
        for user in users:
            if user.active:
                active_users = active_users + [user]

    A conditional element (``[a if c else b for ...]``) is expanded into an
    ``if``/``else`` whose branches each append one of the two values.
    Multiple ``if`` filters on the generator are combined with ``and``.

    Args:
        node: The assignment whose right-hand side may be a list comprehension.

    Returns:
        None when the assigned value is not a list comprehension; otherwise
        the IR statements for the accumulator initialization followed by the
        loop.

    Raises:
        UnsupportedPatternError: If the assignment does not target a single
            plain variable, the comprehension has more than one generator,
            the generator is async, or the comprehension refers to the
            variable it is being assigned to.
    """
    if not isinstance(node.value, ast.ListComp):
        return None

    if len(node.targets) != 1 or not isinstance(node.targets[0], ast.Name):
        line = getattr(node, "lineno", None)
        col = getattr(node, "col_offset", None)
        raise UnsupportedPatternError(
            "List comprehension assignments must target a single variable",
            "Assign the comprehension to a simple variable like `results = [x for x in items]`",
            line=line,
            col=col,
        )

    listcomp = node.value
    if len(listcomp.generators) != 1:
        line = getattr(listcomp, "lineno", None)
        col = getattr(listcomp, "col_offset", None)
        raise UnsupportedPatternError(
            "List comprehensions with multiple generators are not supported",
            "Use nested for loops instead of combining multiple generators in one comprehension",
            line=line,
            col=col,
        )

    gen = listcomp.generators[0]
    if gen.is_async:
        line = getattr(listcomp, "lineno", None)
        col = getattr(listcomp, "col_offset", None)
        raise UnsupportedPatternError(
            "Async list comprehensions are not supported",
            "Rewrite using an explicit async for loop",
            line=line,
            col=col,
        )

    target_name = node.targets[0].id

    # The expansion resets the target to [] *before* the loop runs, so a
    # comprehension that reads its own target (as iterable, element, filter
    # or loop variable) would observe the emptied/growing accumulator
    # instead of the previous value, e.g. `xs = [f(x) for x in xs]` would
    # always produce []. Reject it rather than silently miscompile.
    references_target = any(
        isinstance(sub, ast.Name) and sub.id == target_name for sub in ast.walk(listcomp)
    )
    if references_target:
        line = getattr(listcomp, "lineno", None)
        col = getattr(listcomp, "col_offset", None)
        raise UnsupportedPatternError(
            "List comprehensions that reference their own assignment target are not supported",
            "Assign the comprehension to a new variable, e.g. `doubled = [x * 2 for x in items]`",
            line=line,
            col=col,
        )

    # Initialize the accumulator list: active_users = []
    init_assign_ast = ast.Assign(
        targets=[ast.Name(id=target_name, ctx=ast.Store())],
        value=ast.List(elts=[], ctx=ast.Load()),
        type_comment=None,
    )
    ast.copy_location(init_assign_ast, node)
    ast.fix_missing_locations(init_assign_ast)

    def _make_append_assignment(value_expr: ast.expr) -> ast.Assign:
        # Builds `target = target + [value_expr]`; the value is deep-copied
        # so the synthesized tree never shares nodes with the original.
        append_assign = ast.Assign(
            targets=[ast.Name(id=target_name, ctx=ast.Store())],
            value=ast.BinOp(
                left=ast.Name(id=target_name, ctx=ast.Load()),
                op=ast.Add(),
                right=ast.List(elts=[copy.deepcopy(value_expr)], ctx=ast.Load()),
            ),
            type_comment=None,
        )
        ast.copy_location(append_assign, node.value)
        ast.fix_missing_locations(append_assign)
        return append_assign

    append_statements: List[ast.stmt] = []
    if isinstance(listcomp.elt, ast.IfExp):
        # `[a if c else b for ...]` -> if c: append a / else: append b
        then_assign = _make_append_assignment(listcomp.elt.body)
        else_assign = _make_append_assignment(listcomp.elt.orelse)
        branch_if = ast.If(
            test=copy.deepcopy(listcomp.elt.test),
            body=[then_assign],
            orelse=[else_assign],
        )
        ast.copy_location(branch_if, listcomp.elt)
        ast.fix_missing_locations(branch_if)
        append_statements.append(branch_if)
    else:
        append_statements.append(_make_append_assignment(listcomp.elt))

    loop_body: List[ast.stmt] = append_statements
    if gen.ifs:
        # Several trailing `if` filters are equivalent to one `and` chain.
        condition: ast.expr
        if len(gen.ifs) == 1:
            condition = copy.deepcopy(gen.ifs[0])
        else:
            condition = ast.BoolOp(op=ast.And(), values=[copy.deepcopy(iff) for iff in gen.ifs])
        ast.copy_location(condition, gen.ifs[0])
        if_stmt = ast.If(test=condition, body=append_statements, orelse=[])
        ast.copy_location(if_stmt, gen.ifs[0])
        ast.fix_missing_locations(if_stmt)
        loop_body = [if_stmt]

    loop_ast = ast.For(
        target=copy.deepcopy(gen.target),
        iter=copy.deepcopy(gen.iter),
        body=loop_body,
        orelse=[],
        type_comment=None,
    )
    ast.copy_location(loop_ast, node)
    ast.fix_missing_locations(loop_ast)

    # Lower the synthesized Python AST through the regular visitors.
    statements: List[ir.Statement] = []
    init_stmt = self._visit_assign(init_assign_ast)
    if init_stmt:
        statements.append(init_stmt)
    statements.extend(self._visit_for(loop_ast))

    return statements
1009
+
1010
def _expand_dict_comprehension_assignment(
    self, node: ast.Assign
) -> Optional[List[ir.Statement]]:
    """Expand a dict comprehension assignment into loop-based statements.

    Example:
        result = {k: v for k, v in pairs if v}

    Becomes:
        result = {}
        for k, v in pairs:
            if v:
                result[k] = v

    Multiple ``if`` filters on the generator are combined with ``and``.

    Args:
        node: The assignment whose right-hand side may be a dict comprehension.

    Returns:
        None when the assigned value is not a dict comprehension; otherwise
        the IR statements for the accumulator initialization followed by the
        loop.

    Raises:
        UnsupportedPatternError: If the assignment does not target a single
            plain variable, the comprehension has more than one generator,
            the generator is async, or the comprehension refers to the
            variable it is being assigned to.
    """
    if not isinstance(node.value, ast.DictComp):
        return None

    if len(node.targets) != 1 or not isinstance(node.targets[0], ast.Name):
        line = getattr(node, "lineno", None)
        col = getattr(node, "col_offset", None)
        raise UnsupportedPatternError(
            "Dict comprehension assignments must target a single variable",
            "Assign the comprehension to a simple variable like `result = {k: v for k, v in pairs}`",
            line=line,
            col=col,
        )

    dictcomp = node.value
    if len(dictcomp.generators) != 1:
        line = getattr(dictcomp, "lineno", None)
        col = getattr(dictcomp, "col_offset", None)
        raise UnsupportedPatternError(
            "Dict comprehensions with multiple generators are not supported",
            "Use nested for loops instead of combining multiple generators in one comprehension",
            line=line,
            col=col,
        )

    gen = dictcomp.generators[0]
    if gen.is_async:
        line = getattr(dictcomp, "lineno", None)
        col = getattr(dictcomp, "col_offset", None)
        raise UnsupportedPatternError(
            "Async dict comprehensions are not supported",
            "Rewrite using an explicit async for loop",
            line=line,
            col=col,
        )

    target_name = node.targets[0].id

    # The expansion resets the target to {} *before* the loop runs, so a
    # comprehension that reads its own target (e.g.
    # `d = {k: v for k, v in d.items()}`) would iterate the emptied dict
    # and always produce {}. Reject it rather than silently miscompile.
    references_target = any(
        isinstance(sub, ast.Name) and sub.id == target_name for sub in ast.walk(dictcomp)
    )
    if references_target:
        line = getattr(dictcomp, "lineno", None)
        col = getattr(dictcomp, "col_offset", None)
        raise UnsupportedPatternError(
            "Dict comprehensions that reference their own assignment target are not supported",
            "Assign the comprehension to a new variable, e.g. `updated = {k: v for k, v in data.items()}`",
            line=line,
            col=col,
        )

    # Initialize accumulator: result = {}
    init_assign_ast = ast.Assign(
        targets=[ast.Name(id=target_name, ctx=ast.Store())],
        value=ast.Dict(keys=[], values=[]),
        type_comment=None,
    )
    ast.copy_location(init_assign_ast, node)
    ast.fix_missing_locations(init_assign_ast)

    # result[key] = value (key/value are deep-copied so the synthesized
    # tree never shares nodes with the original comprehension)
    subscript_target = ast.Subscript(
        value=ast.Name(id=target_name, ctx=ast.Load()),
        slice=copy.deepcopy(dictcomp.key),
        ctx=ast.Store(),
    )
    append_assign_ast = ast.Assign(
        targets=[subscript_target],
        value=copy.deepcopy(dictcomp.value),
        type_comment=None,
    )
    ast.copy_location(append_assign_ast, node.value)
    ast.fix_missing_locations(append_assign_ast)

    loop_body: List[ast.stmt] = []
    if gen.ifs:
        # Several trailing `if` filters are equivalent to one `and` chain.
        condition: ast.expr
        if len(gen.ifs) == 1:
            condition = copy.deepcopy(gen.ifs[0])
        else:
            condition = ast.BoolOp(op=ast.And(), values=[copy.deepcopy(iff) for iff in gen.ifs])
        ast.copy_location(condition, gen.ifs[0])
        if_stmt = ast.If(test=condition, body=[append_assign_ast], orelse=[])
        ast.copy_location(if_stmt, gen.ifs[0])
        ast.fix_missing_locations(if_stmt)
        loop_body.append(if_stmt)
    else:
        loop_body.append(append_assign_ast)

    loop_ast = ast.For(
        target=copy.deepcopy(gen.target),
        iter=copy.deepcopy(gen.iter),
        body=loop_body,
        orelse=[],
        type_comment=None,
    )
    ast.copy_location(loop_ast, node)
    ast.fix_missing_locations(loop_ast)

    # Lower the synthesized Python AST through the regular visitors.
    statements: List[ir.Statement] = []
    init_stmt = self._visit_assign(init_assign_ast)
    if init_stmt:
        statements.append(init_stmt)
    statements.extend(self._visit_for(loop_ast))

    return statements
1106
+
1107
def _visit_assign(self, node: ast.Assign) -> Optional[ir.Statement]:
    """Lower a Python assignment into an IR ``Assignment`` statement.

    Every right-hand side form is wrapped in the same ``Assignment`` message,
    so unpacking works uniformly for:
    - Action calls: a, b = @get_pair()
    - Parallel blocks: a, b = parallel: @x() @y()
    - Regular expressions: a, b = some_list

    The right-hand side is classified in this order: model constructor
    (converted to a dict expression), awaited ``asyncio.gather`` (parallel or
    spread expression), action call, then any other expression.

    Returns:
        The assignment statement, or None when the right-hand side cannot be
        converted to an IR expression.

    Raises UnsupportedPatternError (via the helpers it calls) for:
    - Constructor calls: x = MyClass(...)
    - Non-action await: x = await some_func()
    """
    stmt = ir.Statement(span=_make_span(node))
    targets = self._get_assign_targets(node.targets)
    rhs = node.value

    def finish(value: ir.Expr) -> ir.Statement:
        # Attach `targets = value` to the statement created above.
        stmt.assignment.CopyFrom(ir.Assignment(targets=targets, value=value))
        return stmt

    # Pydantic model / dataclass constructors become dict expressions.
    model_name = self._is_model_constructor(rhs)
    if model_name and isinstance(rhs, ast.Call):
        return finish(self._convert_model_constructor_to_dict(rhs, model_name))

    # Any other constructor call is rejected. This has to run after the
    # model check above because model constructors are allowed.
    self._check_constructor_in_assignment(rhs)

    # `await asyncio.gather(...)` becomes a parallel or spread expression.
    if isinstance(rhs, ast.Await) and isinstance(rhs.value, ast.Call):
        awaited_call = rhs.value
        if self._is_asyncio_gather_call(awaited_call):
            gathered = self._convert_asyncio_gather(awaited_call)
            if gathered is not None:
                if isinstance(gathered, ir.ParallelExpr):
                    wrapped = ir.Expr(parallel_expr=gathered, span=_make_span(node))
                else:
                    # Anything else returned by the converter is a SpreadExpr.
                    wrapped = ir.Expr(spread_expr=gathered, span=_make_span(node))
                return finish(wrapped)

    # Action calls are wrapped in an expression so unpacking is uniform.
    action_call = self._extract_action_call(rhs)
    if action_call:
        return finish(ir.Expr(action_call=action_call, span=_make_span(node)))

    # Plain values: variables, literals, arbitrary expressions.
    converted = _expr_to_ir(rhs)
    if not converted:
        return None
    return finish(converted)
1167
+
1168
def _visit_expr_stmt(self, node: ast.Expr) -> Optional[ir.Statement]:
    """Lower a bare expression statement (evaluated for side effects) to IR.

    Handled forms, in order:
    1. ``await asyncio.gather(...)`` -> parallel block statement, or an
       assignment with no targets when the gather converts to a spread.
    2. An action call -> action call statement.
    3. ``name.append(value)`` -> ``name = name + [value]``.
    4. Anything else -> expression statement.

    Returns:
        The IR statement, or None when the expression cannot be converted.
    """
    stmt = ir.Statement(span=_make_span(node))
    inner = node.value

    # 1. asyncio.gather() used purely for its side effects.
    if isinstance(inner, ast.Await) and isinstance(inner.value, ast.Call):
        awaited_call = inner.value
        if self._is_asyncio_gather_call(awaited_call):
            gathered = self._convert_asyncio_gather(awaited_call)
            if gathered is not None:
                if not isinstance(gathered, ir.ParallelExpr):
                    # SpreadExpr as a side effect, e.g.
                    #   await asyncio.gather(*[action(x) for x in items])
                    # is carried by an assignment that has no targets.
                    spread_value = ir.Expr(spread_expr=gathered, span=_make_span(node))
                    stmt.assignment.CopyFrom(ir.Assignment(targets=[], value=spread_value))
                    return stmt
                # Fixed set of calls: emit a ParallelBlock statement.
                block = ir.ParallelBlock()
                block.calls.extend(gathered.calls)
                stmt.parallel_block.CopyFrom(block)
                return stmt

    # 2. Action call whose result is discarded.
    action_call = self._extract_action_call(inner)
    if action_call:
        stmt.action_call.CopyFrom(action_call)
        return stmt

    # 3. name.append(x) is rewritten to name = name + [x], making the
    #    mutation explicit so data flows correctly through the DAG.
    if isinstance(inner, ast.Call):
        callee = inner.func
        is_simple_append = (
            isinstance(callee, ast.Attribute)
            and callee.attr == "append"
            and isinstance(callee.value, ast.Name)
            and len(inner.args) == 1
        )
        if is_simple_append:
            list_name = callee.value.id
            list_var = ir.Expr(variable=ir.Variable(name=list_name), span=_make_span(node))
            appended = _expr_to_ir(inner.args[0])
            if appended:
                singleton = ir.Expr(list=ir.ListExpr(elements=[appended]), span=_make_span(node))
                concatenation = ir.Expr(
                    binary_op=ir.BinaryOp(
                        op=ir.BinaryOperator.BINARY_OP_ADD, left=list_var, right=singleton
                    ),
                    span=_make_span(node),
                )
                stmt.assignment.CopyFrom(
                    ir.Assignment(targets=[list_name], value=concatenation)
                )
                return stmt
            # An unconvertible argument falls through to the generic path.

    # 4. Generic expression statement.
    converted = _expr_to_ir(inner)
    if not converted:
        return None
    stmt.expr_stmt.CopyFrom(ir.ExprStmt(expr=converted))
    return stmt
1236
+
1237
def _visit_for(self, node: ast.For) -> List[ir.Statement]:
    """Lower a ``for`` loop to a single IR ``ForLoop`` statement.

    The body is emitted as a full block, so it may hold several
    statements/calls as well as an early ``return``.

    Returns:
        A one-element list with the loop statement, or an empty list when
        the iterable cannot be converted to an IR expression.
    """
    # Loop variables: a single name, or the plain names inside a tuple
    # target. Any other target shape contributes no variables.
    target = node.target
    if isinstance(target, ast.Name):
        loop_vars: List[str] = [target.id]
    elif isinstance(target, ast.Tuple):
        loop_vars = [elt.id for elt in target.elts if isinstance(elt, ast.Name)]
    else:
        loop_vars = []

    iterable = _expr_to_ir(node.iter)
    if not iterable:
        return []

    # Each body statement may expand into several IR statements.
    body_stmts: List[ir.Statement] = []
    for child in node.body:
        body_stmts.extend(self._visit_statement(child))

    loop_stmt = ir.Statement(span=_make_span(node))
    loop_stmt.for_loop.CopyFrom(
        ir.ForLoop(
            loop_vars=loop_vars,
            iterable=iterable,
            block_body=ir.Block(statements=body_stmts, span=_make_span(node)),
        )
    )
    return [loop_stmt]
1271
+
1272
def _detect_accumulator_targets(
    self, stmts: List[ir.Statement], in_scope_vars: set
) -> List[str]:
    """Find variables from outside a loop body that the body modifies.

    Each statement is checked with ``_extract_accumulator_from_stmt``; the
    bodies of conditional branches (if / elif / else) are searched
    recursively.

    Patterns recognized by the per-statement check:
    1. Mutating method calls: results.append(value) -> "results"
    2. Subscript assignment: result[key] = value -> "result"
    3. Self-referential assignment: x = x + [y] -> "x"

    Returns:
        Accumulator names in first-seen order, without duplicates.
    """
    ordered: List[str] = []
    known: set = set()

    def note(name: Optional[str]) -> None:
        # Record a name once, keeping discovery order.
        if name and name not in known:
            known.add(name)
            ordered.append(name)

    for stmt in stmts:
        note(self._extract_accumulator_from_stmt(stmt, in_scope_vars))

        if not stmt.HasField("conditional"):
            continue

        # Gather the branch bodies in source order: if, elifs, else.
        cond = stmt.conditional
        branch_blocks: list[ir.Block] = []
        if cond.HasField("if_branch") and cond.if_branch.HasField("block_body"):
            branch_blocks.append(cond.if_branch.block_body)
        branch_blocks.extend(
            branch.block_body for branch in cond.elif_branches if branch.HasField("block_body")
        )
        if cond.HasField("else_branch") and cond.else_branch.HasField("block_body"):
            branch_blocks.append(cond.else_branch.block_body)

        for block in branch_blocks:
            nested = self._detect_accumulator_targets(list(block.statements), in_scope_vars)
            for name in nested:
                note(name)

    return ordered
1319
+
1320
def _extract_accumulator_from_stmt(
    self, stmt: ir.Statement, in_scope_vars: set
) -> Optional[str]:
    """Return the accumulator variable a single statement modifies, if any.

    Three patterns are checked, in this order:
    1. An expression statement calling a mutating method
       (``x.append``, ``x.extend``, ``x.add``, ``x.update``, ``x.insert``,
       ``x.pop``, ``x.remove``, ``x.clear``) on a name not in
       ``in_scope_vars``.
    2. An assignment whose target contains ``[`` (subscript store) with a
       base name not in ``in_scope_vars``.
    3. An assignment whose target also appears on its right-hand side.

    Returns:
        The variable name, or None when no pattern matches.
    """
    # Pattern 1: receiver.method(...) where the method mutates the receiver.
    if stmt.HasField("expr_stmt"):
        expr = stmt.expr_stmt.expr
        if expr.HasField("function_call"):
            callee = expr.function_call.name
            mutating_suffixes = (
                ".append",
                ".extend",
                ".add",
                ".update",
                ".insert",
                ".pop",
                ".remove",
                ".clear",
            )
            for suffix in mutating_suffixes:
                if not callee.endswith(suffix):
                    continue
                receiver = callee[: -len(suffix)]
                # Only variables defined outside the loop body qualify.
                if receiver and receiver not in in_scope_vars:
                    return receiver

    if not stmt.HasField("assignment"):
        return None
    assign = stmt.assignment

    # Pattern 2: subscript store such as dict[key] = value. The base
    # variable is everything before the first '['.
    for target in assign.targets:
        if "[" not in target:
            continue
        base_name = target.partition("[")[0]
        if base_name and base_name not in in_scope_vars:
            return base_name

    # Pattern 3: self-referential assignment such as x = x + [y].
    # in_scope_vars is deliberately not consulted: the assignment itself
    # would have put the target in scope, yet it still needs the value the
    # variable had before the loop body ran.
    rhs_names = self._collect_variables_from_expr(assign.value)
    for target in assign.targets:
        if target in rhs_names:
            return target

    return None
1374
+
1375
def _collect_variables_from_expr(self, expr: ir.Expr) -> set:
    """Return the set of variable names referenced anywhere in ``expr``.

    Recurses through binary/unary operands, list elements, dict keys and
    values, index value/index, the object of a dot access, and the keyword
    argument values of function and action calls. Any other expression kind
    contributes nothing.
    """
    if expr.HasField("variable"):
        return {expr.variable.name}

    # Pick the sub-expressions to descend into for this expression kind.
    children: list = []
    if expr.HasField("binary_op"):
        children = [expr.binary_op.left, expr.binary_op.right]
    elif expr.HasField("unary_op"):
        children = [expr.unary_op.operand]
    elif expr.HasField("list"):
        children = list(expr.list.elements)
    elif expr.HasField("dict"):
        children = [*expr.dict.keys, *expr.dict.values]
    elif expr.HasField("index"):
        children = [expr.index.value, expr.index.index]
    elif expr.HasField("dot"):
        children = [expr.dot.object]
    elif expr.HasField("function_call"):
        children = [kwarg.value for kwarg in expr.function_call.kwargs]
    elif expr.HasField("action_call"):
        children = [kwarg.value for kwarg in expr.action_call.kwargs]

    names: set = set()
    for child in children:
        names |= self._collect_variables_from_expr(child)
    return names
1407
+
1408
def _visit_if(self, node: ast.If) -> List[ir.Statement]:
    """Lower an if/elif/else chain to IR.

    A condition that is itself an awaited action call, such as:
        if await some_action(...):
            ...
    is normalized into:
        __if_cond_n__ = await some_action(...)
        if __if_cond_n__:
            ...

    When only the leading ``if`` needs that hoisted assignment (or none do),
    one flat conditional with elif/else branches is produced. When any
    ``elif`` needs it, the chain is emitted as nested if/else statements so
    each hoisted assignment sits inside the preceding else branch.

    Returns:
        The IR statements for the chain (hoisted assignments followed by the
        conditional), or an empty list when the leading condition cannot be
        converted in the flat form.

    Raises:
        UnsupportedPatternError: If a condition contains an action call but
            is not itself a plain ``await`` expression.
    """

    def split_condition(test: ast.expr) -> tuple[List[ir.Statement], Optional[ir.Expr]]:
        # Returns (hoisted statements, condition expression) for one test.
        awaited_action = self._extract_action_call(test)
        if awaited_action is None:
            return ([], _expr_to_ir(test))

        if not isinstance(test, ast.Await):
            raise UnsupportedPatternError(
                "Action calls inside boolean expressions are not supported in if conditions",
                "Assign the awaited action result to a variable, then use the variable in the if condition.",
                line=getattr(test, "lineno", None),
                col=getattr(test, "col_offset", None),
            )

        # Hoist the action call into a synthetic variable and test that.
        temp_name = self._ctx.next_implicit_fn_name(prefix="if_cond")
        hoisted = ir.Statement(span=_make_span(test))
        hoisted.assignment.CopyFrom(
            ir.Assignment(
                targets=[temp_name],
                value=ir.Expr(action_call=awaited_action, span=_make_span(test)),
            )
        )
        temp_ref = ir.Expr(variable=ir.Variable(name=temp_name), span=_make_span(test))
        return ([hoisted], temp_ref)

    def lower_body(nodes: list[ast.stmt]) -> List[ir.Statement]:
        return [lowered for child in nodes for lowered in self._visit_statement(child)]

    def block_for(statements: List[ir.Statement], span_node: ast.AST) -> ir.Block:
        return ir.Block(statements=statements, span=_make_span(span_node))

    def assemble(
        condition: ir.Expr,
        then_body: List[ir.Statement],
        else_statements: List[ir.Statement],
        span_node: ast.AST,
        elifs: List[ir.ElifBranch],
    ) -> ir.Statement:
        # Builds one conditional statement; the else branch is only
        # attached when it has statements.
        conditional = ir.Conditional(
            if_branch=ir.IfBranch(
                condition=condition,
                block_body=block_for(then_body, span_node),
                span=_make_span(span_node),
            )
        )
        for elif_branch in elifs:
            conditional.elif_branches.append(elif_branch)
        if else_statements:
            conditional.else_branch.CopyFrom(
                ir.ElseBranch(
                    block_body=block_for(else_statements, span_node),
                    span=_make_span(span_node),
                )
            )
        wrapper = ir.Statement(span=_make_span(span_node))
        wrapper.conditional.CopyFrom(conditional)
        return wrapper

    # Python represents `elif` as an `If` that is the sole statement of the
    # previous `orelse`; walk that chain to list every if/elif node.
    chain: list[ast.If] = [node]
    while len(chain[-1].orelse) == 1 and isinstance(chain[-1].orelse[0], ast.If):
        chain.append(chain[-1].orelse[0])
    trailing_else = chain[-1].orelse

    # Lower each branch in source order: condition first, then its body.
    lowered_branches: list[
        tuple[List[ir.Statement], Optional[ir.Expr], List[ir.Statement], ast.AST]
    ] = []
    for branch in chain:
        setup, condition = split_condition(branch.test)
        lowered_branches.append((setup, condition, lower_body(branch.body), branch))

    else_body = lower_body(trailing_else) if trailing_else else []

    # A hoisted assignment on any elif must only run when the earlier
    # conditions were false, which requires nesting to preserve Python
    # semantics.
    needs_nesting = any(setup for setup, _, _, _ in lowered_branches[1:])

    if needs_nesting:
        # Build from the innermost branch outwards; each step wraps the
        # previous result as the else body of the branch before it.
        tail: List[ir.Statement] = else_body
        for setup, condition, then_body, span_node in reversed(lowered_branches):
            if condition is None:
                continue
            nested_stmt = assemble(condition, then_body, tail, span_node, [])
            tail = [*setup, nested_stmt]
        return tail

    # Flat form: one conditional carrying elif/else branches, preceded by
    # the hoisted assignment of the leading `if` when there is one.
    first_setup, first_condition, first_body, first_node = lowered_branches[0]
    if first_condition is None:
        return []

    elifs = [
        ir.ElifBranch(
            condition=elif_condition,
            block_body=block_for(elif_body, elif_node),
            span=_make_span(elif_node),
        )
        for _, elif_condition, elif_body, elif_node in lowered_branches[1:]
        if elif_condition is not None
    ]
    flat_stmt = assemble(first_condition, first_body, else_body, first_node, elifs)
    return [*first_setup, flat_stmt]
1545
+
1546
def _collect_assigned_vars(self, stmts: List[ir.Statement]) -> set:
    """Return the set of targets of the assignment statements in ``stmts``.

    Only the given statements are inspected; nested blocks are not searched.
    """
    return {
        target
        for stmt in stmts
        if stmt.HasField("assignment")
        for target in stmt.assignment.targets
    }
1553
+
1554
def _collect_assigned_vars_in_order(self, stmts: List[ir.Statement]) -> list[str]:
    """Collect assigned variable names in statement order (deduplicated).

    Besides top-level assignments, this recurses into the block bodies of
    conditionals (if / elif / else), for loops, and try/except statements
    (try block, then each handler).

    Returns:
        Assignment targets in first-seen order, each listed once.
    """
    assigned: list[str] = []
    seen: set[str] = set()

    def remember(names) -> None:
        # Append unseen names, preserving encounter order.
        for name in names:
            if name not in seen:
                seen.add(name)
                assigned.append(name)

    def remember_block(block: ir.Block) -> None:
        remember(self._collect_assigned_vars_in_order(list(block.statements)))

    for stmt in stmts:
        if stmt.HasField("assignment"):
            remember(stmt.assignment.targets)

        if stmt.HasField("conditional"):
            cond = stmt.conditional
            if cond.HasField("if_branch") and cond.if_branch.HasField("block_body"):
                remember_block(cond.if_branch.block_body)
            for elif_branch in cond.elif_branches:
                if elif_branch.HasField("block_body"):
                    remember_block(elif_branch.block_body)
            if cond.HasField("else_branch") and cond.else_branch.HasField("block_body"):
                remember_block(cond.else_branch.block_body)

        if stmt.HasField("for_loop") and stmt.for_loop.HasField("block_body"):
            remember_block(stmt.for_loop.block_body)

        if stmt.HasField("try_except"):
            try_except = stmt.try_except
            # Guard on the presence of the try block itself. Guarding on
            # the block's span would skip the assignments of a try block
            # that was built without a span.
            if try_except.HasField("try_block"):
                remember_block(try_except.try_block)
            for handler in try_except.handlers:
                if handler.HasField("block_body"):
                    remember_block(handler.block_body)

    return assigned
1616
+
1617
def _collect_variables_from_block(self, block: ir.Block) -> list[str]:
    """Collect variable references from the statements of ``block``."""
    block_statements = list(block.statements)
    return self._collect_variables_from_statements(block_statements)
1619
+
1620
def _collect_variables_from_statements(self, stmts: List[ir.Statement]) -> list[str]:
    """Collect the variable names referenced by ``stmts``, without duplicates.

    Statements are scanned in order and nested blocks are searched
    recursively. Covered statement kinds: assignment values, return values,
    action calls, expression statements, conditionals (conditions and
    bodies), for loops (iterable and body), try/except (try block and
    handlers), parallel blocks, and spread actions.

    Names found within one expression arrive as a set, so their relative
    order in the result follows that set's iteration order.
    """
    vars_found: list[str] = []
    seen: set[str] = set()

    def remember(names) -> None:
        # Append unseen names, preserving encounter order.
        for name in names:
            if name not in seen:
                seen.add(name)
                vars_found.append(name)

    def scan_expr(expr: ir.Expr) -> None:
        remember(self._collect_variables_from_expr(expr))

    def scan_block(block: ir.Block) -> None:
        remember(self._collect_variables_from_block(block))

    def scan_kwargs(call) -> None:
        for kwarg in call.kwargs:
            scan_expr(kwarg.value)

    for stmt in stmts:
        if stmt.HasField("assignment") and stmt.assignment.HasField("value"):
            scan_expr(stmt.assignment.value)

        if stmt.HasField("return_stmt") and stmt.return_stmt.HasField("value"):
            scan_expr(stmt.return_stmt.value)

        if stmt.HasField("action_call"):
            # Wrap the bare call in an expression to reuse the expr walker.
            scan_expr(ir.Expr(action_call=stmt.action_call, span=stmt.span))

        if stmt.HasField("expr_stmt"):
            scan_expr(stmt.expr_stmt.expr)

        if stmt.HasField("conditional"):
            cond = stmt.conditional
            if cond.HasField("if_branch"):
                if cond.if_branch.HasField("condition"):
                    scan_expr(cond.if_branch.condition)
                if cond.if_branch.HasField("block_body"):
                    scan_block(cond.if_branch.block_body)
            for elif_branch in cond.elif_branches:
                if elif_branch.HasField("condition"):
                    scan_expr(elif_branch.condition)
                if elif_branch.HasField("block_body"):
                    scan_block(elif_branch.block_body)
            if cond.HasField("else_branch") and cond.else_branch.HasField("block_body"):
                scan_block(cond.else_branch.block_body)

        if stmt.HasField("for_loop"):
            loop = stmt.for_loop
            if loop.HasField("iterable"):
                scan_expr(loop.iterable)
            if loop.HasField("block_body"):
                scan_block(loop.block_body)

        if stmt.HasField("try_except"):
            try_except = stmt.try_except
            if try_except.HasField("try_block"):
                scan_block(try_except.try_block)
            for handler in try_except.handlers:
                if handler.HasField("block_body"):
                    scan_block(handler.block_body)

        if stmt.HasField("parallel_block"):
            for call in stmt.parallel_block.calls:
                if call.HasField("action"):
                    scan_kwargs(call.action)
                elif call.HasField("function"):
                    scan_kwargs(call.function)

        if stmt.HasField("spread_action"):
            spread = stmt.spread_action
            if spread.HasField("collection"):
                scan_expr(spread.collection)
            if spread.HasField("action"):
                scan_kwargs(spread.action)

    return vars_found
1738
+
1739
def _visit_try(self, node: ast.Try) -> List[ir.Statement]:
    """Convert a ``try``/``except`` statement to IR with full block bodies.

    Only ``node.body`` and ``node.handlers`` are read here; ``else`` /
    ``finally`` clauses and the ``as`` name of a handler are not consulted
    by this method.

    Args:
        node: The ``ast.Try`` node to convert.

    Returns:
        A single-element list holding the ``try_except`` statement.
    """
    # Build try body statements (each child may expand to several statements).
    try_body: List[ir.Statement] = []
    for body_node in node.body:
        try_body.extend(self._visit_statement(body_node))

    handlers: List[ir.ExceptHandler] = []
    for handler in node.handlers:
        # Normalise ``except A:`` and ``except (A, B):`` to one list of nodes.
        if handler.type is None:
            type_nodes: List[ast.expr] = []
        elif isinstance(handler.type, ast.Tuple):
            type_nodes = list(handler.type.elts)
        else:
            type_nodes = [handler.type]

        exception_types: List[str] = []
        for type_node in type_nodes:
            if isinstance(type_node, ast.Name):
                exception_types.append(type_node.id)
            elif isinstance(type_node, ast.Attribute):
                # ``except pkg.SomeError:`` used to be dropped silently, which
                # left the handler with fewer (or no) exception types than the
                # source declared. Record the trailing attribute name, the same
                # way _get_constructor_name resolves dotted callables.
                # NOTE(review): assumes the runtime matches on the bare class
                # name - confirm.
                exception_types.append(type_node.attr)

        # Build handler body (each child may expand to several statements).
        handler_body: List[ir.Statement] = []
        for handler_node in handler.body:
            handler_body.extend(self._visit_statement(handler_node))

        handlers.append(
            ir.ExceptHandler(
                exception_types=exception_types,
                block_body=ir.Block(statements=handler_body, span=_make_span(handler)),
                span=_make_span(handler),
            )
        )

    # Build the try/except statement.
    try_stmt = ir.Statement(span=_make_span(node))
    try_stmt.try_except.CopyFrom(
        ir.TryExcept(
            handlers=handlers,
            try_block=ir.Block(statements=try_body, span=_make_span(node)),
        )
    )
    return [try_stmt]
1781
+
1782
+ def _count_calls(self, stmts: List[ir.Statement]) -> int:
1783
+ """Count action calls and function calls in statements.
1784
+
1785
+ Both action calls and function calls (including synthetic functions)
1786
+ count toward the limit of one call per control flow body.
1787
+ """
1788
+ count = 0
1789
+ for stmt in stmts:
1790
+ if stmt.HasField("action_call"):
1791
+ count += 1
1792
+ elif stmt.HasField("assignment"):
1793
+ # Check if assignment value is an action call or function call
1794
+ if stmt.assignment.value.HasField("action_call"):
1795
+ count += 1
1796
+ elif stmt.assignment.value.HasField("function_call"):
1797
+ count += 1
1798
+ elif stmt.HasField("expr_stmt"):
1799
+ # Check if expression is a function call
1800
+ if stmt.expr_stmt.expr.HasField("function_call"):
1801
+ count += 1
1802
+ return count
1803
+
1804
def _wrap_body_as_function(
    self,
    body: List[ir.Statement],
    prefix: str,
    node: ast.AST,
    inputs: Optional[List[str]] = None,
    modified_vars: Optional[List[str]] = None,
) -> List[ir.Statement]:
    """Move ``body`` into a synthetic function and return a call to it.

    The new function is registered on ``self._ctx.implicit_functions``.

    Args:
        body: The statements to wrap.
        prefix: Name prefix handed to the context's implicit-name generator.
        node: AST node used for span information.
        inputs: Variables to pass as inputs (e.g., loop variables).
        modified_vars: Variables modified in the body. They are appended to
            the inputs (if not already present) and also returned as outputs.

    Returns:
        A one-element list: an assignment of the call result to
        ``modified_vars`` when there are any, otherwise an expression
        statement holding the call.
    """
    synthetic_name = self._ctx.next_implicit_fn_name(prefix)
    outputs = modified_vars or []

    # Inputs are the explicit ones followed by any outputs not already listed.
    params = list(inputs or [])
    for name in outputs:
        if name not in params:
            params.append(name)

    statements = list(body)
    if outputs:
        if len(outputs) == 1:
            returned = ir.Expr(
                variable=ir.Variable(name=outputs[0]),
                span=_make_span(node),
            )
        else:
            # Several outputs are returned together as an IR list.
            returned = ir.Expr(
                list=ir.ListExpr(
                    elements=[ir.Expr(variable=ir.Variable(name=name)) for name in outputs]
                ),
                span=_make_span(node),
            )
        trailing_return = ir.Statement(span=_make_span(node))
        trailing_return.return_stmt.CopyFrom(ir.ReturnStmt(value=returned))
        statements.append(trailing_return)

    # Register the synthetic function definition.
    self._ctx.implicit_functions.append(
        ir.FunctionDef(
            name=synthetic_name,
            io=ir.IoDecl(inputs=params, outputs=outputs),
            body=ir.Block(statements=statements),
            span=_make_span(node),
        )
    )

    # Every input is forwarded as a same-named keyword argument.
    call_expr = ir.Expr(
        function_call=ir.FunctionCall(
            name=synthetic_name,
            kwargs=[
                ir.Kwarg(name=name, value=ir.Expr(variable=ir.Variable(name=name)))
                for name in params
            ],
        ),
        span=_make_span(node),
    )

    result = ir.Statement(span=_make_span(node))
    if outputs:
        # var1, var2 = fn(...)  or  var1 = fn(...)
        assignment = ir.Assignment(value=call_expr)
        assignment.targets.extend(outputs)
        result.assignment.CopyFrom(assignment)
    else:
        result.expr_stmt.CopyFrom(ir.ExprStmt(expr=call_expr))
    return [result]
1886
+
1887
def _visit_return(self, node: ast.Return) -> List[ir.Statement]:
    """Convert a ``return`` statement to IR.

    A returned action call or function call is normalised into two
    statements so that the return itself only references a variable:

        return await action()

    becomes:

        _return_tmp = await action()
        return _return_tmp

    Function calls use a temporary name generated by the context (prefix
    ``return_tmp``) instead of the fixed ``_return_tmp``.

    Constructor-like calls are rejected first via
    ``_check_constructor_in_return``.
    """

    def _assign_then_return(tmp_name: str, value: ir.Expr) -> List[ir.Statement]:
        # Emit ``tmp_name = value`` followed by ``return tmp_name``.
        assign_stmt = ir.Statement(span=_make_span(node))
        assign_stmt.assignment.CopyFrom(ir.Assignment(targets=[tmp_name], value=value))

        return_stmt = ir.Statement(span=_make_span(node))
        tmp_ref = ir.Expr(variable=ir.Variable(name=tmp_name), span=_make_span(node))
        return_stmt.return_stmt.CopyFrom(ir.ReturnStmt(value=tmp_ref))
        return [assign_stmt, return_stmt]

    if node.value:
        # Reject things like ``return MyModel(...)`` up front.
        self._check_constructor_in_return(node.value)

        action_call = self._extract_action_call(node.value)
        if action_call:
            return _assign_then_return(
                "_return_tmp",
                ir.Expr(action_call=action_call, span=_make_span(node)),
            )

        expr = _expr_to_ir(node.value)
        if expr and expr.HasField("function_call"):
            return _assign_then_return(
                self._ctx.next_implicit_fn_name(prefix="return_tmp"),
                expr,
            )

        if expr:
            # Plain expression (variable, literal, etc.).
            plain = ir.Statement(span=_make_span(node))
            plain.return_stmt.CopyFrom(ir.ReturnStmt(value=expr))
            return [plain]

    # No value, or a value that could not be converted: emit a bare return.
    bare = ir.Statement(span=_make_span(node))
    bare.return_stmt.CopyFrom(ir.ReturnStmt())
    return [bare]
1948
+
1949
def _visit_aug_assign(self, node: ast.AugAssign) -> List[ir.Statement]:
    """Convert an augmented assignment (``+=``, ``-=``, ...) to IR.

    ``target op= value`` is rewritten as ``target = target op value``. When
    the right-hand side is a function call it is first stored in a generated
    temporary (prefix ``aug_tmp``) and the temporary is used as the right
    operand.

    Returns:
        The emitted statements. The list is empty when either side cannot be
        converted or the operator is not recognised; in the function-call
        case only the temporary assignment is returned under those
        conditions.
    """
    stmt = ir.Statement(span=_make_span(node))

    # Only simple names are recorded as assignment targets.
    targets: List[str] = [node.target.id] if isinstance(node.target, ast.Name) else []

    left = _expr_to_ir(node.target)
    right = _expr_to_ir(node.value)

    if right and right.HasField("function_call"):
        tmp_var = self._ctx.next_implicit_fn_name(prefix="aug_tmp")

        # tmp_var = <function call>
        assign_tmp = ir.Statement(span=_make_span(node))
        assign_tmp.assignment.CopyFrom(
            ir.Assignment(
                targets=[tmp_var],
                value=ir.Expr(function_call=right.function_call, span=_make_span(node)),
            )
        )

        if not left:
            return [assign_tmp]
        op = _bin_op_to_ir(node.op)
        if not op:
            return [assign_tmp]

        # target = target op tmp_var
        combined = ir.BinaryOp(
            left=left,
            op=op,
            right=ir.Expr(variable=ir.Variable(name=tmp_var)),
        )
        stmt.assignment.CopyFrom(
            ir.Assignment(targets=targets, value=ir.Expr(binary_op=combined))
        )
        return [assign_tmp, stmt]

    if not (left and right):
        return []
    op = _bin_op_to_ir(node.op)
    if not op:
        return []

    # target = target op value
    combined = ir.BinaryOp(left=left, op=op, right=right)
    stmt.assignment.CopyFrom(
        ir.Assignment(targets=targets, value=ir.Expr(binary_op=combined))
    )
    return [stmt]
1996
+
1997
+ def _check_constructor_in_return(self, node: ast.expr) -> None:
1998
+ """Check for constructor calls in return statements.
1999
+
2000
+ Raises UnsupportedPatternError if the return value is a class instantiation
2001
+ like: return MyModel(field=value)
2002
+
2003
+ This is not supported because the workflow IR cannot serialize arbitrary
2004
+ object instantiation. Users should use an @action to create objects.
2005
+ """
2006
+ # Skip if it's an await (action call) - those are fine
2007
+ if isinstance(node, ast.Await):
2008
+ return
2009
+
2010
+ # Check for direct Call that looks like a constructor
2011
+ if isinstance(node, ast.Call):
2012
+ func_name = self._get_constructor_name(node.func)
2013
+ if func_name and self._looks_like_constructor(func_name, node):
2014
+ line = getattr(node, "lineno", None)
2015
+ col = getattr(node, "col_offset", None)
2016
+ raise UnsupportedPatternError(
2017
+ f"Returning constructor call '{func_name}(...)' is not supported",
2018
+ RECOMMENDATIONS["constructor_return"],
2019
+ line=line,
2020
+ col=col,
2021
+ )
2022
+
2023
+ def _check_constructor_in_assignment(self, node: ast.expr) -> None:
2024
+ """Check for constructor calls in assignments.
2025
+
2026
+ Raises UnsupportedPatternError if the assignment value is a class instantiation
2027
+ like: result = MyModel(field=value)
2028
+
2029
+ This is not supported because the workflow IR cannot serialize arbitrary
2030
+ object instantiation. Users should use an @action to create objects.
2031
+ """
2032
+ # Skip if it's an await (action call) - those are fine
2033
+ if isinstance(node, ast.Await):
2034
+ return
2035
+
2036
+ # Check for direct Call that looks like a constructor
2037
+ if isinstance(node, ast.Call):
2038
+ func_name = self._get_constructor_name(node.func)
2039
+ if func_name and self._looks_like_constructor(func_name, node):
2040
+ line = getattr(node, "lineno", None)
2041
+ col = getattr(node, "col_offset", None)
2042
+ raise UnsupportedPatternError(
2043
+ f"Assigning constructor call '{func_name}(...)' is not supported",
2044
+ RECOMMENDATIONS["constructor_assignment"],
2045
+ line=line,
2046
+ col=col,
2047
+ )
2048
+
2049
+ def _get_constructor_name(self, func: ast.expr) -> Optional[str]:
2050
+ """Get the name from a function expression if it looks like a constructor."""
2051
+ if isinstance(func, ast.Name):
2052
+ return func.id
2053
+ elif isinstance(func, ast.Attribute):
2054
+ return func.attr
2055
+ return None
2056
+
2057
+ def _looks_like_constructor(self, func_name: str, call: ast.Call) -> bool:
2058
+ """Check if a function call looks like a class constructor.
2059
+
2060
+ A constructor is identified by:
2061
+ 1. Name starts with uppercase (PEP8 convention for classes)
2062
+ 2. It's not a known action
2063
+ 3. It's not a known builtin like String operations
2064
+ 4. It's not a known Pydantic model or dataclass (those are allowed)
2065
+
2066
+ This is a heuristic - we can't perfectly distinguish constructors
2067
+ from functions without full type information.
2068
+ """
2069
+ # Check if first letter is uppercase (class naming convention)
2070
+ if not func_name or not func_name[0].isupper():
2071
+ return False
2072
+
2073
+ # If it's a known action, it's not a constructor
2074
+ if func_name in self._action_defs:
2075
+ return False
2076
+
2077
+ # If it's a known Pydantic model or dataclass, allow it
2078
+ # (it will be converted to a dict expression)
2079
+ if func_name in self._model_defs:
2080
+ return False
2081
+
2082
+ # Common builtins that start with uppercase but aren't constructors
2083
+ # (these are rarely used in workflow code but let's be safe)
2084
+ builtin_exceptions = {"True", "False", "None", "Ellipsis"}
2085
+ if func_name in builtin_exceptions:
2086
+ return False
2087
+
2088
+ return True
2089
+
2090
+ def _is_model_constructor(self, node: ast.expr) -> Optional[str]:
2091
+ """Check if an expression is a Pydantic model or dataclass constructor call.
2092
+
2093
+ Returns the model name if it is, None otherwise.
2094
+ """
2095
+ if not isinstance(node, ast.Call):
2096
+ return None
2097
+
2098
+ func_name = self._get_constructor_name(node.func)
2099
+ if func_name and func_name in self._model_defs:
2100
+ return func_name
2101
+
2102
+ return None
2103
+
2104
def _convert_model_constructor_to_dict(self, node: ast.Call, model_name: str) -> ir.Expr:
    """Convert a known model constructor call into an IR dict expression.

    For example:
        MyModel(field1=value1, field2=value2)
    becomes:
        {"field1": value1, "field2": value2}

    Entries are emitted in this order: keyword arguments, positional
    arguments (mapped onto the model's fields by position), then defaults
    for fields that were not provided. A default that cannot be turned into
    an IR literal is left out.

    Args:
        node: The constructor call.
        model_name: Key into ``self._model_defs`` for the called model.

    Raises:
        UnsupportedPatternError: On ``**kwargs`` expansion, on more
            positional arguments than fields, or when an argument value
            cannot be converted to IR.
    """
    model_def = self._model_defs[model_name]

    def _string_key(text: str) -> ir.Expr:
        # Dict keys are string literals wrapped in an expression.
        literal = ir.Literal()
        literal.string_value = text
        key = ir.Expr()
        key.literal.CopyFrom(literal)
        return key

    def _unsupported(message: str, recommendation: str) -> UnsupportedPatternError:
        # All errors raised here point at the constructor call itself.
        return UnsupportedPatternError(
            message,
            recommendation,
            line=getattr(node, "lineno", None),
            col=getattr(node, "col_offset", None),
        )

    entries: List[ir.DictEntry] = []
    provided_fields: Set[str] = set()

    # Explicit keyword arguments come first.
    for kw in node.keywords:
        if kw.arg is None:
            # ``**kwargs`` expansion has no static field names.
            raise _unsupported(
                f"Model constructor '{model_name}' with **kwargs is not supported",
                "Use explicit keyword arguments instead of **kwargs.",
            )

        provided_fields.add(kw.arg)
        converted = _expr_to_ir(kw.value)
        if converted is None:
            raise _unsupported(
                f"Cannot convert value for field '{kw.arg}' in '{model_name}'",
                "Use simpler expressions (literals, variables, dicts, lists).",
            )
        entries.append(ir.DictEntry(key=_string_key(kw.arg), value=converted))

    # Positional arguments map onto the model's fields in declaration order.
    if node.args:
        field_names = list(model_def.fields.keys())
        for position, arg in enumerate(node.args):
            if position >= len(field_names):
                raise _unsupported(
                    f"Too many positional arguments for '{model_name}'",
                    "Use keyword arguments for clarity.",
                )

            field_name = field_names[position]
            provided_fields.add(field_name)

            converted = _expr_to_ir(arg)
            if converted is None:
                raise _unsupported(
                    f"Cannot convert positional argument for field '{field_name}' in '{model_name}'",
                    "Use simpler expressions (literals, variables, dicts, lists).",
                )
            entries.append(ir.DictEntry(key=_string_key(field_name), value=converted))

    # Fill in defaults for everything the call did not provide.
    for field_name, field_def in model_def.fields.items():
        if field_name in provided_fields or not field_def.has_default:
            continue

        default_literal = _constant_to_literal(field_def.default_value)
        if default_literal is None:
            # Not representable as a literal; omit the field.
            continue

        default_expr = ir.Expr()
        default_expr.literal.CopyFrom(default_literal)
        entries.append(ir.DictEntry(key=_string_key(field_name), value=default_expr))

    result = ir.Expr(span=_make_span(node))
    result.dict.CopyFrom(ir.DictExpr(entries=entries))
    return result
2216
+
2217
+ def _check_non_action_await(self, node: ast.Await) -> None:
2218
+ """Check if an await is for a non-action function.
2219
+
2220
+ Note: We can only reliably detect non-action awaits for functions defined
2221
+ in the same module. Actions imported from other modules will pass through
2222
+ and may fail at runtime if they're not actually actions.
2223
+
2224
+ For now, we only check against common builtins and known non-action patterns.
2225
+ A runtime check will catch functions that aren't registered actions.
2226
+ """
2227
+ awaited = node.value
2228
+ if not isinstance(awaited, ast.Call):
2229
+ return
2230
+
2231
+ # Skip special cases that are handled elsewhere
2232
+ if self._is_run_action_call(awaited):
2233
+ return
2234
+ if self._is_asyncio_sleep_call(awaited):
2235
+ return
2236
+ if self._is_asyncio_gather_call(awaited):
2237
+ return
2238
+
2239
+ # Get the function name
2240
+ func_name = None
2241
+ if isinstance(awaited.func, ast.Name):
2242
+ func_name = awaited.func.id
2243
+ elif isinstance(awaited.func, ast.Attribute):
2244
+ func_name = awaited.func.attr
2245
+
2246
+ if not func_name:
2247
+ return
2248
+
2249
+ # Only raise error for functions defined in THIS module that we know
2250
+ # are NOT actions (i.e., async functions without @action decorator)
2251
+ # We can't reliably detect imported non-actions without full type info.
2252
+ #
2253
+ # The check works by looking at _module_functions which contains
2254
+ # functions defined in the same module as the workflow.
2255
+ if func_name in getattr(self, "_module_functions", set()):
2256
+ if func_name not in self._action_defs:
2257
+ line = getattr(node, "lineno", None)
2258
+ col = getattr(node, "col_offset", None)
2259
+ raise UnsupportedPatternError(
2260
+ f"Awaiting non-action function '{func_name}()' is not supported",
2261
+ RECOMMENDATIONS["non_action_call"],
2262
+ line=line,
2263
+ col=col,
2264
+ )
2265
+
2266
+ def _check_sync_function_call(self, node: ast.Call) -> None:
2267
+ """Check for synchronous function calls that should be in actions.
2268
+
2269
+ Common patterns like len(), str(), etc. are not supported in workflow code.
2270
+ """
2271
+ func_name = None
2272
+ if isinstance(node.func, ast.Name):
2273
+ func_name = node.func.id
2274
+ elif isinstance(node.func, ast.Attribute):
2275
+ # Method calls on objects - check the method name
2276
+ func_name = node.func.attr
2277
+
2278
+ if not func_name:
2279
+ return
2280
+
2281
+ # Builtins that users commonly try to use
2282
+ common_builtins = {
2283
+ "len",
2284
+ "str",
2285
+ "int",
2286
+ "float",
2287
+ "bool",
2288
+ "list",
2289
+ "dict",
2290
+ "set",
2291
+ "tuple",
2292
+ "sum",
2293
+ "min",
2294
+ "max",
2295
+ "sorted",
2296
+ "reversed",
2297
+ "enumerate",
2298
+ "zip",
2299
+ "map",
2300
+ "filter",
2301
+ "range",
2302
+ "abs",
2303
+ "round",
2304
+ "print",
2305
+ "type",
2306
+ "isinstance",
2307
+ "hasattr",
2308
+ "getattr",
2309
+ "setattr",
2310
+ "open",
2311
+ "format",
2312
+ }
2313
+
2314
+ if func_name in common_builtins:
2315
+ line = getattr(node, "lineno", None)
2316
+ col = getattr(node, "col_offset", None)
2317
+ raise UnsupportedPatternError(
2318
+ f"Calling built-in function '{func_name}()' directly is not supported",
2319
+ RECOMMENDATIONS["builtin_call"],
2320
+ line=line,
2321
+ col=col,
2322
+ )
2323
+
2324
+ def _extract_action_call(self, node: ast.expr) -> Optional[ir.ActionCall]:
2325
+ """Extract an action call from an expression if present.
2326
+
2327
+ Also validates that awaited calls are actually @action decorated functions.
2328
+ Raises UnsupportedPatternError if awaiting a non-action function.
2329
+ """
2330
+ if not isinstance(node, ast.Await):
2331
+ return None
2332
+
2333
+ awaited = node.value
2334
+ # Handle self.run_action(...) wrapper
2335
+ if isinstance(awaited, ast.Call):
2336
+ if self._is_run_action_call(awaited):
2337
+ # Extract the actual action call from run_action
2338
+ if awaited.args:
2339
+ action_call = self._extract_action_call_from_awaitable(awaited.args[0])
2340
+ if action_call:
2341
+ # Extract policies from run_action kwargs (retry, timeout)
2342
+ self._extract_policies_from_run_action(awaited, action_call)
2343
+ return action_call
2344
+ # Check for asyncio.sleep() - convert to @sleep action
2345
+ if self._is_asyncio_sleep_call(awaited):
2346
+ return self._convert_asyncio_sleep_to_action(awaited)
2347
+ # Try to extract as action call
2348
+ action_call = self._extract_action_call_from_call(awaited)
2349
+ if action_call:
2350
+ return action_call
2351
+
2352
+ # If we get here, it's an await of a non-action function
2353
+ self._check_non_action_await(node)
2354
+ return None
2355
+
2356
+ return None
2357
+
2358
+ def _is_run_action_call(self, node: ast.Call) -> bool:
2359
+ """Check if this is a self.run_action(...) call."""
2360
+ if isinstance(node.func, ast.Attribute):
2361
+ return node.func.attr == "run_action"
2362
+ return False
2363
+
2364
+ def _extract_policies_from_run_action(
2365
+ self, run_action_call: ast.Call, action_call: ir.ActionCall
2366
+ ) -> None:
2367
+ """Extract retry and timeout policies from run_action kwargs.
2368
+
2369
+ Parses patterns like:
2370
+ - self.run_action(action(), retry=RetryPolicy(attempts=3))
2371
+ - self.run_action(action(), timeout=timedelta(seconds=30))
2372
+ - self.run_action(action(), timeout=60)
2373
+ """
2374
+ for kw in run_action_call.keywords:
2375
+ if kw.arg == "retry":
2376
+ retry_policy = self._parse_retry_policy(kw.value)
2377
+ if retry_policy:
2378
+ policy_bracket = ir.PolicyBracket()
2379
+ policy_bracket.retry.CopyFrom(retry_policy)
2380
+ action_call.policies.append(policy_bracket)
2381
+ elif kw.arg == "timeout":
2382
+ timeout_policy = self._parse_timeout_policy(kw.value)
2383
+ if timeout_policy:
2384
+ policy_bracket = ir.PolicyBracket()
2385
+ policy_bracket.timeout.CopyFrom(timeout_policy)
2386
+ action_call.policies.append(policy_bracket)
2387
+
2388
+ def _parse_retry_policy(self, node: ast.expr) -> Optional[ir.RetryPolicy]:
2389
+ """Parse a RetryPolicy(...) call into IR.
2390
+
2391
+ Supports:
2392
+ - RetryPolicy(attempts=3)
2393
+ - RetryPolicy(attempts=3, exception_types=["ValueError"])
2394
+ - RetryPolicy(attempts=3, backoff_seconds=5)
2395
+ """
2396
+ if not isinstance(node, ast.Call):
2397
+ return None
2398
+
2399
+ # Check if it's a RetryPolicy call
2400
+ func_name = None
2401
+ if isinstance(node.func, ast.Name):
2402
+ func_name = node.func.id
2403
+ elif isinstance(node.func, ast.Attribute):
2404
+ func_name = node.func.attr
2405
+
2406
+ if func_name != "RetryPolicy":
2407
+ return None
2408
+
2409
+ policy = ir.RetryPolicy()
2410
+
2411
+ for kw in node.keywords:
2412
+ if kw.arg == "attempts" and isinstance(kw.value, ast.Constant):
2413
+ policy.max_retries = kw.value.value
2414
+ elif kw.arg == "exception_types" and isinstance(kw.value, ast.List):
2415
+ for elt in kw.value.elts:
2416
+ if isinstance(elt, ast.Constant) and isinstance(elt.value, str):
2417
+ policy.exception_types.append(elt.value)
2418
+ elif kw.arg == "backoff_seconds" and isinstance(kw.value, ast.Constant):
2419
+ policy.backoff.seconds = int(kw.value.value)
2420
+
2421
+ return policy
2422
+
2423
+ def _parse_timeout_policy(self, node: ast.expr) -> Optional[ir.TimeoutPolicy]:
2424
+ """Parse a timeout value into IR.
2425
+
2426
+ Supports:
2427
+ - timeout=60 (int seconds)
2428
+ - timeout=30.5 (float seconds)
2429
+ - timeout=timedelta(seconds=30)
2430
+ - timeout=timedelta(minutes=2)
2431
+ """
2432
+ policy = ir.TimeoutPolicy()
2433
+
2434
+ # Direct numeric value (seconds)
2435
+ if isinstance(node, ast.Constant) and isinstance(node.value, (int, float)):
2436
+ policy.timeout.seconds = int(node.value)
2437
+ return policy
2438
+
2439
+ # timedelta(...) call
2440
+ if isinstance(node, ast.Call):
2441
+ func_name = None
2442
+ if isinstance(node.func, ast.Name):
2443
+ func_name = node.func.id
2444
+ elif isinstance(node.func, ast.Attribute):
2445
+ func_name = node.func.attr
2446
+
2447
+ if func_name == "timedelta":
2448
+ total_seconds = 0
2449
+ for kw in node.keywords:
2450
+ if isinstance(kw.value, ast.Constant):
2451
+ val = kw.value.value
2452
+ if kw.arg == "seconds":
2453
+ total_seconds += int(val)
2454
+ elif kw.arg == "minutes":
2455
+ total_seconds += int(val) * 60
2456
+ elif kw.arg == "hours":
2457
+ total_seconds += int(val) * 3600
2458
+ elif kw.arg == "days":
2459
+ total_seconds += int(val) * 86400
2460
+ policy.timeout.seconds = total_seconds
2461
+ return policy
2462
+
2463
+ return None
2464
+
2465
+ def _is_asyncio_sleep_call(self, node: ast.Call) -> bool:
2466
+ """Check if this is an asyncio.sleep(...) call.
2467
+
2468
+ Supports both patterns:
2469
+ - import asyncio; asyncio.sleep(1)
2470
+ - from asyncio import sleep; sleep(1)
2471
+ - from asyncio import sleep as s; s(1)
2472
+ """
2473
+ if isinstance(node.func, ast.Attribute):
2474
+ # asyncio.sleep(...) pattern
2475
+ if node.func.attr == "sleep" and isinstance(node.func.value, ast.Name):
2476
+ return node.func.value.id == "asyncio"
2477
+ elif isinstance(node.func, ast.Name):
2478
+ # sleep(...) pattern - check if it's imported from asyncio
2479
+ func_name = node.func.id
2480
+ if func_name in self._imported_names:
2481
+ imported = self._imported_names[func_name]
2482
+ return imported.module == "asyncio" and imported.original_name == "sleep"
2483
+ return False
2484
+
2485
def _convert_asyncio_sleep_to_action(self, node: ast.Call) -> ir.ActionCall:
    """Convert ``asyncio.sleep(duration)`` into a ``sleep`` action call.

    The duration is taken from the first positional argument, or failing
    that from the first keyword named ``seconds``, ``delay`` or
    ``duration``, and is passed to the action as the ``duration`` kwarg.
    If the chosen argument cannot be converted to IR the action is returned
    without a ``duration`` kwarg.
    """
    sleep_call = ir.ActionCall(action_name="sleep")

    if node.args:
        # asyncio.sleep(1) - positional
        source = node.args[0]
    else:
        # asyncio.sleep(seconds=1) - keyword (less common); first match wins.
        matching = next(
            (kw for kw in node.keywords if kw.arg in ("seconds", "delay", "duration")),
            None,
        )
        source = matching.value if matching is not None else None

    if source is not None:
        duration = _expr_to_ir(source)
        if duration:
            sleep_call.kwargs.append(ir.Kwarg(name="duration", value=duration))

    return sleep_call
2509
+
2510
+ def _is_asyncio_gather_call(self, node: ast.Call) -> bool:
2511
+ """Check if this is an asyncio.gather(...) call.
2512
+
2513
+ Supports both patterns:
2514
+ - import asyncio; asyncio.gather(a(), b())
2515
+ - from asyncio import gather; gather(a(), b())
2516
+ - from asyncio import gather as g; g(a(), b())
2517
+ """
2518
+ if isinstance(node.func, ast.Attribute):
2519
+ # asyncio.gather(...) pattern
2520
+ if node.func.attr == "gather" and isinstance(node.func.value, ast.Name):
2521
+ return node.func.value.id == "asyncio"
2522
+ elif isinstance(node.func, ast.Name):
2523
+ # gather(...) pattern - check if it's imported from asyncio
2524
+ func_name = node.func.id
2525
+ if func_name in self._imported_names:
2526
+ imported = self._imported_names[func_name]
2527
+ return imported.module == "asyncio" and imported.original_name == "gather"
2528
+ return False
2529
+
2530
def _convert_asyncio_gather(
    self, node: ast.Call
) -> Optional[Union[ir.ParallelExpr, ir.SpreadExpr]]:
    """Convert ``asyncio.gather(...)`` to ParallelExpr or SpreadExpr IR.

    Two shapes are handled:
        1. Static gather: asyncio.gather(a(), b(), c()) -> ParallelExpr
        2. Spread gather: asyncio.gather(*[action(x) for x in items]) -> SpreadExpr

    Args:
        node: The asyncio.gather() Call node.

    Returns:
        A ParallelExpr, a SpreadExpr, or None when no argument of a static
        gather could be converted to a call.

    Raises:
        UnsupportedPatternError: When the single starred argument is not a
            list comprehension.
    """
    # A single starred argument selects the spread pattern.
    if len(node.args) == 1 and isinstance(node.args[0], ast.Starred):
        spread_source = node.args[0].value
        if isinstance(spread_source, ast.ListComp):
            return self._convert_listcomp_to_spread_expr(spread_source)

        # Anything other than a list comprehension cannot be spread.
        if isinstance(spread_source, ast.Name):
            message = (
                f"Spreading variable '{spread_source.id}' in asyncio.gather() is not supported"
            )
        else:
            message = (
                "Spreading non-list-comprehension expressions in asyncio.gather() is not supported"
            )
        raise UnsupportedPatternError(
            message,
            RECOMMENDATIONS["gather_variable_spread"],
            line=getattr(node, "lineno", None),
            col=getattr(node, "col_offset", None),
        )

    # Standard case: gather(a(), b(), c()) -> ParallelExpr.
    parallel = ir.ParallelExpr()
    for arg in node.args:
        converted = self._convert_gather_arg_to_call(arg)
        if converted:
            # Arguments that do not convert are left out of the result.
            parallel.calls.append(converted)

    return parallel if parallel.calls else None
2585
+
2586
def _convert_listcomp_to_spread_expr(self, listcomp: ast.ListComp) -> Optional[ir.SpreadExpr]:
    """Lower ``[action(x=item) for item in collection]`` to a SpreadExpr.

    The comprehension is accepted only when it has exactly one generator,
    no ``if`` filters, a plain-name loop variable, a collection that
    ``_expr_to_ir`` can convert, and an element that is a call to a
    registered action.

    Args:
        listcomp: The ListComp AST node.

    Returns:
        The populated SpreadExpr.

    Raises:
        UnsupportedPatternError: If any of the requirements above is not met.
            The error carries the comprehension's line and column.
    """

    def reject(message: str, recommendation: str) -> None:
        # Every failure is reported at the comprehension's own location.
        raise UnsupportedPatternError(
            message,
            recommendation,
            line=getattr(listcomp, "lineno", None),
            col=getattr(listcomp, "col_offset", None),
        )

    generators = listcomp.generators
    if len(generators) != 1:
        reject(
            "Spread pattern only supports a single loop variable",
            "Use a simple list comprehension: [action(x) for x in items]",
        )
    generator = generators[0]

    if generator.ifs:
        reject(
            "Spread pattern does not support conditions in list comprehension",
            "Remove the 'if' clause from the comprehension",
        )

    target = generator.target
    if not isinstance(target, ast.Name):
        reject(
            "Spread pattern requires a simple loop variable",
            "Use a simple variable: [action(x) for x in items]",
        )

    collection_expr = _expr_to_ir(generator.iter)
    if not collection_expr:
        reject(
            "Could not convert collection expression in spread pattern",
            "Ensure the collection is a simple variable or expression",
        )

    element = listcomp.elt
    if not isinstance(element, ast.Call):
        reject(
            "Spread pattern requires an action call in the list comprehension",
            "Use: [action(x=item) for item in items]",
        )

    action_call = self._extract_action_call_from_call(element)
    if not action_call:
        reject(
            "Spread pattern element must be an @action call",
            "Ensure the function is decorated with @action",
        )

    spread = ir.SpreadExpr()
    spread.collection.CopyFrom(collection_expr)
    spread.loop_var = target.id
    spread.action.CopyFrom(action_call)
    return spread
2678
+
2679
+ def _convert_gather_arg_to_call(self, node: ast.expr) -> Optional[ir.Call]:
2680
+ """Convert a gather argument to an IR Call.
2681
+
2682
+ Handles both action calls and regular function calls.
2683
+ """
2684
+ if not isinstance(node, ast.Call):
2685
+ return None
2686
+
2687
+ # Try to extract as an action call first
2688
+ action_call = self._extract_action_call_from_call(node)
2689
+ if action_call:
2690
+ call = ir.Call()
2691
+ call.action.CopyFrom(action_call)
2692
+ return call
2693
+
2694
+ # Fall back to regular function call
2695
+ func_call = self._convert_to_function_call(node)
2696
+ if func_call:
2697
+ call = ir.Call()
2698
+ call.function.CopyFrom(func_call)
2699
+ return call
2700
+
2701
+ return None
2702
+
2703
def _convert_to_function_call(self, node: ast.Call) -> Optional[ir.FunctionCall]:
    """Build an IR FunctionCall from an AST Call node.

    Args:
        node: The call to convert.

    Returns:
        The FunctionCall, or None when ``self._get_func_name`` cannot name
        the callee. Arguments whose conversion yields nothing, and keyword
        entries without a name (``**mapping``), are left out.
    """
    callee = self._get_func_name(node.func)
    if not callee:
        return None

    call_ir = ir.FunctionCall(name=callee)

    # Positional arguments, in source order.
    for positional in node.args:
        converted = _expr_to_ir(positional)
        if converted:
            call_ir.args.append(converted)

    # Named keyword arguments, in source order.
    for keyword in node.keywords:
        if not keyword.arg:
            continue
        converted = _expr_to_ir(keyword.value)
        if converted:
            call_ir.kwargs.append(ir.Kwarg(name=keyword.arg, value=converted))

    return call_ir
2725
+
2726
+ def _get_func_name(self, node: ast.expr) -> Optional[str]:
2727
+ """Get function name from a func node."""
2728
+ if isinstance(node, ast.Name):
2729
+ return node.id
2730
+ elif isinstance(node, ast.Attribute):
2731
+ # Handle chained attributes like obj.method
2732
+ parts = []
2733
+ current = node
2734
+ while isinstance(current, ast.Attribute):
2735
+ parts.append(current.attr)
2736
+ current = current.value
2737
+ if isinstance(current, ast.Name):
2738
+ parts.append(current.id)
2739
+ name = ".".join(reversed(parts))
2740
+ if name.startswith("self."):
2741
+ return name[5:]
2742
+ return name
2743
+ return None
2744
+
2745
def _expr_to_ir_with_model_coercion(self, node: ast.expr) -> Optional[ir.Expr]:
    """Convert an expression to IR, turning model constructors into dicts.

    Used for action arguments: a call recognised by
    ``self._is_model_constructor`` (e.g. ``MyModel(field=value)``) is
    lowered through ``self._convert_model_constructor_to_dict``. Everything
    else goes through the plain ``_expr_to_ir`` conversion.

    Args:
        node: The argument expression.

    Returns:
        The converted IR expression, or None if conversion produced nothing.
    """
    model_name = self._is_model_constructor(node) if isinstance(node, ast.Call) else None
    if model_name:
        return self._convert_model_constructor_to_dict(node, model_name)

    # Not a model constructor: standard conversion.
    return _expr_to_ir(node)
2763
+
2764
+ def _extract_action_call_from_awaitable(self, node: ast.expr) -> Optional[ir.ActionCall]:
2765
+ """Extract action call from an awaitable expression."""
2766
+ if isinstance(node, ast.Call):
2767
+ return self._extract_action_call_from_call(node)
2768
+ return None
2769
+
2770
def _extract_action_call_from_call(self, node: ast.Call) -> Optional[ir.ActionCall]:
    """Build an IR ActionCall from a call to a registered action.

    Positional arguments are paired with parameter names taken from the
    action's signature so that every argument is named in the IR. Argument
    values go through ``_expr_to_ir_with_model_coercion``, so model
    constructors are lowered to dict expressions.

    Args:
        node: The call node to inspect.

    Returns:
        The ActionCall, or None when the callee is not a key of
        ``self._action_defs``.
    """
    name = self._get_action_name(node.func)
    if not name or name not in self._action_defs:
        return None

    definition = self._action_defs[name]
    call_ir = ir.ActionCall(action_name=definition.action_name)

    # The worker uses the module name to locate the action.
    if definition.module_name:
        call_ir.module_name = definition.module_name

    parameter_names = list(definition.signature.parameters)

    # zip() stops at the shorter sequence, so positional arguments beyond the
    # declared parameters are left out, as are nameless (``**mapping``)
    # keyword entries.
    bindings = list(zip(parameter_names, node.args))
    bindings.extend((kw.arg, kw.value) for kw in node.keywords if kw.arg)

    for param_name, value_node in bindings:
        value = self._expr_to_ir_with_model_coercion(value_node)
        if value:
            call_ir.kwargs.append(ir.Kwarg(name=param_name, value=value))

    return call_ir
2815
+
2816
+ def _get_action_name(self, func: ast.expr) -> Optional[str]:
2817
+ """Get the action name from a function expression."""
2818
+ if isinstance(func, ast.Name):
2819
+ return func.id
2820
+ elif isinstance(func, ast.Attribute):
2821
+ return func.attr
2822
+ return None
2823
+
2824
+ def _get_assign_target(self, targets: List[ast.expr]) -> Optional[str]:
2825
+ """Get the target variable name from assignment targets (single target only)."""
2826
+ if targets and isinstance(targets[0], ast.Name):
2827
+ return targets[0].id
2828
+ return None
2829
+
2830
+ def _get_assign_targets(self, targets: List[ast.expr]) -> List[str]:
2831
+ """Get all target variable names from assignment targets (including tuple unpacking)."""
2832
+ result: List[str] = []
2833
+ for t in targets:
2834
+ if isinstance(t, ast.Name):
2835
+ result.append(t.id)
2836
+ elif isinstance(t, ast.Subscript):
2837
+ formatted = _format_subscript_target(t)
2838
+ if formatted:
2839
+ result.append(formatted)
2840
+ elif isinstance(t, ast.Tuple):
2841
+ for elt in t.elts:
2842
+ if isinstance(elt, ast.Name):
2843
+ result.append(elt.id)
2844
+ return result
2845
+
2846
+
2847
def _make_span(node: ast.AST) -> ir.Span:
    """Build an IR Span from an AST node's position attributes.

    Args:
        node: Any AST node.

    Returns:
        A Span in which every position that is missing on the node defaults
        to 0. End positions that are present but None also become 0.
    """
    start_line = getattr(node, "lineno", 0)
    start_col = getattr(node, "col_offset", 0)
    # End positions may be absent or None; normalise both to 0.
    end_line = getattr(node, "end_lineno", 0) or 0
    end_col = getattr(node, "end_col_offset", 0) or 0
    return ir.Span(
        start_line=start_line,
        start_col=start_col,
        end_line=end_line,
        end_col=end_col,
    )
2855
+
2856
+
2857
def _call_node_to_function_call(call: ast.Call) -> Optional[ir.FunctionCall]:
    """Lower an ``ast.Call`` to an IR FunctionCall.

    Returns None when ``_get_func_name`` cannot name the callee. Arguments
    whose conversion yields nothing, and keyword entries without a name
    (``**mapping``), are left out.
    """
    func_name = _get_func_name(call.func)
    if not func_name:
        return None

    args = [_expr_to_ir(a) for a in call.args]
    kwargs: List[ir.Kwarg] = []
    for kw in call.keywords:
        if kw.arg:
            kw_expr = _expr_to_ir(kw.value)
            if kw_expr:
                kwargs.append(ir.Kwarg(name=kw.arg, value=kw_expr))

    return ir.FunctionCall(name=func_name, args=[a for a in args if a], kwargs=kwargs)


def _expr_to_ir(expr: ast.AST) -> Optional[ir.Expr]:
    """Convert a Python AST expression to an IR Expr.

    Handles names, constants, binary/unary/boolean operators, comparisons,
    list/tuple/dict displays, subscripts, attribute access, calls and awaited
    calls.

    Chained comparisons such as ``a < b < c`` are lowered to
    ``(a < b) and (b < c)``. The shared middle operand therefore appears
    twice in the IR.

    Args:
        expr: The AST node to convert.

    Returns:
        The IR expression, or None when the node cannot be represented and
        ``_check_unsupported_expression`` did not raise for it.

    Raises:
        UnsupportedPatternError: Via ``_check_unsupported_expression`` when
            the node (or one of its parts) is an unsupported construct.
    """
    span = _make_span(expr)
    result = ir.Expr(span=span)

    if isinstance(expr, ast.Name):
        result.variable.CopyFrom(ir.Variable(name=expr.id))
        return result

    if isinstance(expr, ast.Constant):
        literal = _constant_to_literal(expr.value)
        if literal:
            result.literal.CopyFrom(literal)
            return result

    if isinstance(expr, ast.BinOp):
        left = _expr_to_ir(expr.left)
        right = _expr_to_ir(expr.right)
        op = _bin_op_to_ir(expr.op)
        # Operators are enum integers: compare against None so that an enum
        # member numbered 0 is not mistaken for "unsupported".
        if left and right and op is not None:
            result.binary_op.CopyFrom(ir.BinaryOp(left=left, op=op, right=right))
            return result

    if isinstance(expr, ast.UnaryOp):
        operand = _expr_to_ir(expr.operand)
        op = _unary_op_to_ir(expr.op)
        if operand and op is not None:
            result.unary_op.CopyFrom(ir.UnaryOp(op=op, operand=operand))
            return result

    if isinstance(expr, ast.Compare):
        left = _expr_to_ir(expr.left)
        if not left:
            return None
        if expr.ops and expr.comparators:
            operands = [left]
            operators = []
            convertible = True
            for cmp_node, rhs_node in zip(expr.ops, expr.comparators):
                op = _cmp_op_to_ir(cmp_node)
                right = _expr_to_ir(rhs_node)
                if op is None or not right:
                    # Unsupported link (e.g. `is`): fall through to the
                    # unsupported-expression check below.
                    convertible = False
                    break
                operators.append(op)
                operands.append(right)

            if convertible:
                # One BinaryOp per link, AND-ed left to right. A single
                # comparison produces exactly one BinaryOp.
                chained: Optional[ir.Expr] = None
                for position, op in enumerate(operators):
                    link = ir.Expr(span=span)
                    link.binary_op.CopyFrom(
                        ir.BinaryOp(left=operands[position], op=op, right=operands[position + 1])
                    )
                    if chained is None:
                        chained = link
                    else:
                        conjunction = ir.Expr(span=span)
                        conjunction.binary_op.CopyFrom(
                            ir.BinaryOp(
                                left=chained,
                                op=ir.BinaryOperator.BINARY_OP_AND,
                                right=link,
                            )
                        )
                        chained = conjunction
                return chained

    if isinstance(expr, ast.BoolOp):
        values = [_expr_to_ir(v) for v in expr.values]
        if all(v for v in values):
            op = _bool_op_to_ir(expr.op)
            if op is not None and len(values) >= 2:
                # Left-associative chain: a and b and c -> (a and b) and c.
                result_expr = values[0]
                for v in values[1:]:
                    # Each chained node carries the span of the whole
                    # boolean expression, like every other branch.
                    new_result = ir.Expr(span=span)
                    new_result.binary_op.CopyFrom(ir.BinaryOp(left=result_expr, op=op, right=v))
                    result_expr = new_result
                return result_expr

    if isinstance(expr, (ast.List, ast.Tuple)):
        # Tuples are represented as lists in the IR.
        elements = [_expr_to_ir(e) for e in expr.elts]
        if all(e for e in elements):
            result.list.CopyFrom(ir.ListExpr(elements=[e for e in elements if e]))
            return result

    if isinstance(expr, ast.Dict):
        entries: List[ir.DictEntry] = []
        for k, v in zip(expr.keys, expr.values, strict=False):
            # A None key marks a `**mapping` entry, which is skipped.
            if k:
                key_expr = _expr_to_ir(k)
                value_expr = _expr_to_ir(v)
                if key_expr and value_expr:
                    entries.append(ir.DictEntry(key=key_expr, value=value_expr))
        result.dict.CopyFrom(ir.DictExpr(entries=entries))
        return result

    if isinstance(expr, ast.Subscript):
        obj = _expr_to_ir(expr.value)
        index = _expr_to_ir(expr.slice) if isinstance(expr.slice, ast.AST) else None
        if obj and index:
            result.index.CopyFrom(ir.IndexAccess(object=obj, index=index))
            return result

    if isinstance(expr, ast.Attribute):
        obj = _expr_to_ir(expr.value)
        if obj:
            result.dot.CopyFrom(ir.DotAccess(object=obj, attribute=expr.attr))
            return result

    # `await f(...)` and `f(...)` are both represented as a FunctionCall.
    call_node: Optional[ast.Call] = None
    if isinstance(expr, ast.Await) and isinstance(expr.value, ast.Call):
        call_node = expr.value
    elif isinstance(expr, ast.Call):
        call_node = expr
    if call_node is not None:
        func_call = _call_node_to_function_call(call_node)
        if func_call is not None:
            result.function_call.CopyFrom(func_call)
            return result

    # Nothing matched: raise a descriptive error for known-unsupported nodes.
    _check_unsupported_expression(expr)

    return None
2992
+
2993
+
2994
+ def _check_unsupported_expression(expr: ast.AST) -> None:
2995
+ """Check for unsupported expression types and raise descriptive errors."""
2996
+ line = getattr(expr, "lineno", None)
2997
+ col = getattr(expr, "col_offset", None)
2998
+
2999
+ if isinstance(expr, ast.Constant):
3000
+ if _constant_to_literal(expr.value) is None:
3001
+ raise UnsupportedPatternError(
3002
+ f"Unsupported literal type '{type(expr.value).__name__}'",
3003
+ RECOMMENDATIONS["unsupported_literal"],
3004
+ line=line,
3005
+ col=col,
3006
+ )
3007
+
3008
+ if isinstance(expr, ast.JoinedStr):
3009
+ raise UnsupportedPatternError(
3010
+ "F-strings are not supported",
3011
+ RECOMMENDATIONS["fstring"],
3012
+ line=line,
3013
+ col=col,
3014
+ )
3015
+ elif isinstance(expr, ast.Lambda):
3016
+ raise UnsupportedPatternError(
3017
+ "Lambda expressions are not supported",
3018
+ RECOMMENDATIONS["lambda"],
3019
+ line=line,
3020
+ col=col,
3021
+ )
3022
+ elif isinstance(expr, ast.ListComp):
3023
+ raise UnsupportedPatternError(
3024
+ "List comprehensions are not supported in this context",
3025
+ RECOMMENDATIONS["list_comprehension"],
3026
+ line=line,
3027
+ col=col,
3028
+ )
3029
+ elif isinstance(expr, ast.DictComp):
3030
+ raise UnsupportedPatternError(
3031
+ "Dict comprehensions are not supported in this context",
3032
+ RECOMMENDATIONS["dict_comprehension"],
3033
+ line=line,
3034
+ col=col,
3035
+ )
3036
+ elif isinstance(expr, ast.SetComp):
3037
+ raise UnsupportedPatternError(
3038
+ "Set comprehensions are not supported",
3039
+ RECOMMENDATIONS["set_comprehension"],
3040
+ line=line,
3041
+ col=col,
3042
+ )
3043
+ elif isinstance(expr, ast.GeneratorExp):
3044
+ raise UnsupportedPatternError(
3045
+ "Generator expressions are not supported",
3046
+ RECOMMENDATIONS["generator"],
3047
+ line=line,
3048
+ col=col,
3049
+ )
3050
+ elif isinstance(expr, ast.NamedExpr):
3051
+ raise UnsupportedPatternError(
3052
+ "The walrus operator (:=) is not supported",
3053
+ RECOMMENDATIONS["walrus"],
3054
+ line=line,
3055
+ col=col,
3056
+ )
3057
+ elif isinstance(expr, ast.Yield) or isinstance(expr, ast.YieldFrom):
3058
+ raise UnsupportedPatternError(
3059
+ "Yield expressions are not supported",
3060
+ RECOMMENDATIONS["yield_statement"],
3061
+ line=line,
3062
+ col=col,
3063
+ )
3064
+ elif isinstance(expr, ast.expr):
3065
+ raise UnsupportedPatternError(
3066
+ f"Unsupported expression type '{type(expr).__name__}'",
3067
+ RECOMMENDATIONS["unsupported_expression"],
3068
+ line=line,
3069
+ col=col,
3070
+ )
3071
+
3072
+
3073
+ def _format_subscript_target(target: ast.Subscript) -> Optional[str]:
3074
+ """Convert a subscript target to a string representation for tracking."""
3075
+ if not isinstance(target.value, ast.Name):
3076
+ return None
3077
+
3078
+ base = target.value.id
3079
+ try:
3080
+ index_str = ast.unparse(target.slice)
3081
+ except Exception:
3082
+ return None
3083
+
3084
+ return f"{base}[{index_str}]"
3085
+
3086
+
3087
def _constant_to_literal(value: Any) -> Optional[ir.Literal]:
    """Wrap a Python constant in an IR Literal.

    Args:
        value: The constant's Python value.

    Returns:
        A Literal for None, bool, int, float and str values; None for every
        other type.
    """
    if value is not None and not isinstance(value, (bool, int, float, str)):
        return None

    literal = ir.Literal()
    if value is None:
        literal.is_none = True
    elif isinstance(value, bool):
        # bool is tested before int because bool is a subclass of int.
        literal.bool_value = value
    elif isinstance(value, int):
        literal.int_value = value
    elif isinstance(value, float):
        literal.float_value = value
    else:
        literal.string_value = value
    return literal
3103
+
3104
+
3105
def _bin_op_to_ir(op: ast.operator) -> Optional[ir.BinaryOperator]:
    """Translate an arithmetic AST operator node to its IR enum value.

    Args:
        op: The operator node of an ``ast.BinOp``.

    Returns:
        The IR operator for ``+``, ``-``, ``*``, ``/``, ``//`` and ``%``;
        None for any other operator.
    """
    operators = ir.BinaryOperator
    table = {
        ast.Add: operators.BINARY_OP_ADD,
        ast.Sub: operators.BINARY_OP_SUB,
        ast.Mult: operators.BINARY_OP_MUL,
        ast.Div: operators.BINARY_OP_DIV,
        ast.FloorDiv: operators.BINARY_OP_FLOOR_DIV,
        ast.Mod: operators.BINARY_OP_MOD,
    }
    node_type = type(op)
    if node_type in table:
        return table[node_type]
    return None
3116
+
3117
+
3118
def _unary_op_to_ir(op: ast.unaryop) -> Optional[ir.UnaryOperator]:
    """Translate a unary AST operator node to its IR enum value.

    Args:
        op: The operator node of an ``ast.UnaryOp``.

    Returns:
        The IR operator for arithmetic negation and ``not``; None for any
        other unary operator.
    """
    operators = ir.UnaryOperator
    table = {
        ast.USub: operators.UNARY_OP_NEG,
        ast.Not: operators.UNARY_OP_NOT,
    }
    node_type = type(op)
    if node_type in table:
        return table[node_type]
    return None
3125
+
3126
+
3127
def _cmp_op_to_ir(op: ast.cmpop) -> Optional[ir.BinaryOperator]:
    """Translate a comparison AST operator node to its IR enum value.

    Args:
        op: One entry of an ``ast.Compare`` node's ``ops`` list.

    Returns:
        The IR operator for ``==``, ``!=``, ``<``, ``<=``, ``>``, ``>=``,
        ``in`` and ``not in``; None for any other comparison operator.
    """
    operators = ir.BinaryOperator
    table = {
        ast.Eq: operators.BINARY_OP_EQ,
        ast.NotEq: operators.BINARY_OP_NE,
        ast.Lt: operators.BINARY_OP_LT,
        ast.LtE: operators.BINARY_OP_LE,
        ast.Gt: operators.BINARY_OP_GT,
        ast.GtE: operators.BINARY_OP_GE,
        ast.In: operators.BINARY_OP_IN,
        ast.NotIn: operators.BINARY_OP_NOT_IN,
    }
    node_type = type(op)
    if node_type in table:
        return table[node_type]
    return None
3140
+
3141
+
3142
def _bool_op_to_ir(op: ast.boolop) -> Optional[ir.BinaryOperator]:
    """Translate a boolean AST operator node to its IR enum value.

    Args:
        op: The operator node of an ``ast.BoolOp``.

    Returns:
        The IR operator for ``and`` and ``or``; None for anything else.
    """
    operators = ir.BinaryOperator
    table = {
        ast.And: operators.BINARY_OP_AND,
        ast.Or: operators.BINARY_OP_OR,
    }
    node_type = type(op)
    if node_type in table:
        return table[node_type]
    return None
3149
+
3150
+
3151
+ def _get_func_name(func: ast.expr) -> Optional[str]:
3152
+ """Get function name from a Call's func attribute."""
3153
+ if isinstance(func, ast.Name):
3154
+ return func.id
3155
+ elif isinstance(func, ast.Attribute):
3156
+ # For method calls like obj.method, return full dotted name
3157
+ parts = []
3158
+ current = func
3159
+ while isinstance(current, ast.Attribute):
3160
+ parts.append(current.attr)
3161
+ current = current.value
3162
+ if isinstance(current, ast.Name):
3163
+ parts.append(current.id)
3164
+ name = ".".join(reversed(parts))
3165
+ if name.startswith("self."):
3166
+ return name[5:]
3167
+ return name
3168
+ return None