rappel 0.7.2__py3-none-macosx_15_0_arm64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of rappel might be problematic. Click here for more details.

rappel/ir_builder.py ADDED
@@ -0,0 +1,3664 @@
1
+ """
2
+ IR Builder - Converts Python workflow AST to Rappel IR (ast.proto).
3
+
4
+ This module parses Python workflow classes and produces the IR representation
5
+ that can be sent to the Rust runtime for execution.
6
+
7
+ The IR builder performs deep transformations to convert Python patterns into
8
+ valid Rappel IR structures. Control flow bodies (try, for, if branches) are
9
+ represented as full blocks, so they can contain arbitrary statements and
10
+ `return` behaves consistently with Python (returns from the entry function end
11
+ the workflow).
12
+
13
+ Validation:
14
+ The IR builder proactively detects unsupported Python patterns and raises
15
+ UnsupportedPatternError with clear recommendations for how to rewrite the code.
16
+ """
17
+
18
+ from __future__ import annotations
19
+
20
+ import ast
21
+ import copy
22
+ import inspect
23
+ import textwrap
24
+ from dataclasses import dataclass
25
+ from enum import EnumMeta
26
+ from typing import TYPE_CHECKING, Any, Callable, Dict, List, Mapping, NoReturn, Optional, Set, Union
27
+
28
+ from proto import ast_pb2 as ir
29
+ from rappel.registry import registry
30
+
31
+
32
class UnsupportedPatternError(Exception):
    """Raised when the IR builder encounters an unsupported Python pattern.

    The exception text is ``message``, followed by an optional parenthesised
    location (``filename``, ``line``, ``col`` - whichever are available) and
    then a ``Recommendation:`` paragraph describing how to rewrite the code.

    Args:
        message: Short description of the unsupported pattern.
        recommendation: Guidance on how to rewrite the code.
        line: Line number of the offending node, if known.
        col: Column of the offending node, if known.
        filename: Source file containing the offending node, if known.
    """

    def __init__(
        self,
        message: str,
        recommendation: str,
        line: Optional[int] = None,
        col: Optional[int] = None,
        filename: Optional[str] = None,
    ):
        self.message = message
        self.recommendation = recommendation
        self.line = line
        self.col = col
        self.filename = filename

        location_parts: List[str] = []
        if filename:
            location_parts.append(filename)
        # Test against None, exactly as for ``col``, so "unknown" is the only
        # value that suppresses the line from the rendered location.
        if line is not None:
            location_parts.append(f"line {line}")
        if col is not None:
            location_parts.append(f"col {col}")
        location = f" ({', '.join(location_parts)})" if location_parts else ""
        full_message = f"{message}{location}\n\nRecommendation: {recommendation}"
        super().__init__(full_message)

    def __reduce__(self):
        """Support ``copy`` and ``pickle``.

        The default exception reduction re-calls the class with ``self.args``
        (only the formatted message), which fails because ``recommendation``
        is a required argument. Rebuild from the original constructor
        arguments instead and carry over any extra instance attributes.
        """
        ctor_args = (self.message, self.recommendation, self.line, self.col, self.filename)
        return (type(self), ctor_args, dict(vars(self)))
63
+
64
+
65
# Recommendations for common unsupported patterns.
#
# Maps a short pattern identifier (e.g. "while_loop") to a human-readable,
# possibly multi-line explanation of why the pattern is unsupported and how to
# rewrite it. Each value is built from adjacent string literals, so the line
# breaks in the final text come only from the explicit "\n" escapes.
# Presumably these strings are passed as the ``recommendation`` argument of
# UnsupportedPatternError - confirm against the call sites.
# NOTE(review): the embedded code examples carry a single leading space per
# line in this view; confirm against the upstream source whether deeper
# indentation was intended.
RECOMMENDATIONS = {
    "constructor_return": (
        "Returning a class constructor (like MyModel(...)) directly is not supported.\n"
        "The workflow IR cannot serialize arbitrary object instantiation.\n\n"
        "Use an @action to create the object:\n\n"
        " @action\n"
        " async def build_result(items: list, count: int) -> MyResult:\n"
        " return MyResult(items=items, count=count)\n\n"
        " # In workflow:\n"
        " return await build_result(items, count)"
    ),
    "constructor_assignment": (
        "Assigning a class constructor result (like x = MyClass(...)) is not supported.\n"
        "The workflow IR cannot serialize arbitrary object instantiation.\n\n"
        "Use an @action to create the object:\n\n"
        " @action\n"
        " async def create_config(value: int) -> Config:\n"
        " return Config(value=value)\n\n"
        " # In workflow:\n"
        " config = await create_config(value)"
    ),
    "non_action_call": (
        "Calling a function that is not decorated with @action is not supported.\n"
        "Only @action decorated functions can be awaited in workflow code.\n\n"
        "Add the @action decorator to your function:\n\n"
        " @action\n"
        " async def my_function(x: int) -> int:\n"
        " return x * 2"
    ),
    "sync_function_call": (
        "Calling a synchronous function directly in workflow code is not supported.\n"
        "All computation must happen inside @action decorated async functions.\n\n"
        "Wrap your logic in an @action:\n\n"
        " @action\n"
        " async def compute(x: int) -> int:\n"
        " return some_sync_function(x)"
    ),
    "method_call_non_self": (
        "Calling methods on objects other than 'self' is not supported in workflow code.\n"
        "Use an @action to perform method calls:\n\n"
        " @action\n"
        " async def call_method(obj: MyClass) -> Result:\n"
        " return obj.some_method()"
    ),
    "builtin_call": (
        "Calling built-in functions like len(), str(), int() directly is not supported.\n"
        "Use an @action to perform these operations:\n\n"
        " @action\n"
        " async def get_length(items: list) -> int:\n"
        " return len(items)"
    ),
    "fstring": (
        "F-strings are not supported in workflow code because they require "
        "runtime string interpolation.\n"
        "Use an @action to perform string formatting:\n\n"
        " @action\n"
        " async def format_message(value: int) -> str:\n"
        " return f'Result: {value}'"
    ),
    "delete": (
        "The 'del' statement is not supported in workflow code.\n"
        "Use an @action to perform mutations:\n\n"
        " @action\n"
        " async def remove_key(data: dict, key: str) -> dict:\n"
        " del data[key]\n"
        " return data"
    ),
    "while_loop": (
        "While loops are not supported in workflow code because they can run "
        "indefinitely.\n"
        "Use a for loop with a fixed range, or restructure as recursive workflow calls."
    ),
    "with_statement": (
        "Context managers (with statements) are not supported in workflow code.\n"
        "Use an @action to handle resource management:\n\n"
        " @action\n"
        " async def read_file(path: str) -> str:\n"
        " with open(path) as f:\n"
        " return f.read()"
    ),
    "raise_statement": (
        "The 'raise' statement is not supported directly in workflow code.\n"
        "Use an @action that raises exceptions, or return error values."
    ),
    "assert_statement": (
        "Assert statements are not supported in workflow code.\n"
        "Use an @action for validation, or use if statements with explicit error handling."
    ),
    "lambda": (
        "Lambda expressions are not supported in workflow code.\n"
        "Use an @action to define the function logic."
    ),
    "list_comprehension": (
        "List comprehensions are only supported when assigned directly to a variable "
        "or inside asyncio.gather(*[...]).\n"
        "For other cases, use a for loop or an @action."
    ),
    "dict_comprehension": (
        "Dict comprehensions are only supported when assigned directly to a variable.\n"
        "For other cases, use a for loop or an @action:\n\n"
        " @action\n"
        " async def build_dict(items: list) -> dict:\n"
        " return {k: v for k, v in items}"
    ),
    "set_comprehension": (
        "Set comprehensions are not supported in workflow code.\nUse an @action to build sets."
    ),
    "generator": (
        "Generator expressions are not supported in workflow code.\n"
        "Use a list or an @action instead."
    ),
    "walrus": (
        "The walrus operator (:=) is not supported in workflow code.\n"
        "Use separate assignment statements instead."
    ),
    "match": (
        "Match statements are not supported in workflow code.\nUse if/elif/else chains instead."
    ),
    # The next two entries share the same rewrite example; only the opening
    # explanation differs.
    "gather_variable_spread": (
        "Spreading a variable in asyncio.gather() is not supported because it requires "
        "data flow analysis to determine the contents.\n"
        "Use a list comprehension directly in gather:\n\n"
        " # Instead of:\n"
        " tasks = []\n"
        " for i in range(count):\n"
        " tasks.append(process(value=i))\n"
        " results = await asyncio.gather(*tasks)\n\n"
        " # Use:\n"
        " results = await asyncio.gather(*[process(value=i) for i in range(count)])"
    ),
    "for_loop_append_pattern": (
        "Building a task list in a for loop then spreading in asyncio.gather() is not "
        "supported.\n"
        "Use a list comprehension directly in gather:\n\n"
        " # Instead of:\n"
        " tasks = []\n"
        " for i in range(count):\n"
        " tasks.append(process(value=i))\n"
        " results = await asyncio.gather(*tasks)\n\n"
        " # Use:\n"
        " results = await asyncio.gather(*[process(value=i) for i in range(count)])"
    ),
    "global_statement": (
        "Global statements are not supported in workflow code.\n"
        "Use workflow state or pass values explicitly."
    ),
    "nonlocal_statement": (
        "Nonlocal statements are not supported in workflow code.\n"
        "Use explicit parameter passing instead."
    ),
    "import_statement": (
        "Import statements inside workflow run() are not supported.\n"
        "Place imports at the module level."
    ),
    "class_def": (
        "Class definitions inside workflow run() are not supported.\n"
        "Define classes at the module level."
    ),
    "function_def": (
        "Nested function definitions inside workflow run() are not supported.\n"
        "Define functions at the module level or use @action."
    ),
    "yield_statement": (
        "Yield statements are not supported in workflow code.\n"
        "Workflows must return a complete result, not generate values incrementally."
    ),
    "continue_statement": (
        "Continue statements are not supported in workflow code.\n"
        "Restructure your loop using if/else to skip iterations."
    ),
    # Generic fallbacks for constructs without a dedicated entry above.
    "unsupported_statement": (
        "This statement type is not supported in workflow code.\n"
        "Move the logic into an @action or rewrite using supported statements."
    ),
    "unsupported_expression": (
        "This expression type is not supported in workflow code.\n"
        "Move the logic into an @action or rewrite using supported expressions."
    ),
    "unsupported_literal": (
        "This literal type is not supported in workflow code.\n"
        "Convert the value to a supported literal type inside an @action."
    ),
}
249
+
250
# Python built-in names that map onto an IR ``GlobalFunction`` enum value.
GLOBAL_FUNCTIONS = dict(
    enumerate=ir.GlobalFunction.GLOBAL_FUNCTION_ENUMERATE,
    len=ir.GlobalFunction.GLOBAL_FUNCTION_LEN,
    range=ir.GlobalFunction.GLOBAL_FUNCTION_RANGE,
)
# Just the names from the table above, as a set.
ALLOWED_SYNC_FUNCTIONS = {fn_name for fn_name in GLOBAL_FUNCTIONS}

# Action names of the module currently being converted; build_workflow_ir
# assigns it per processed function and empties it again when it is done.
_CURRENT_ACTION_NAMES: set[str] = set()
258
+
259
+
260
+ if TYPE_CHECKING:
261
+ from .workflow import Workflow
262
+
263
+
264
@dataclass
class ActionDefinition:
    """Definition of an action function.

    Instances are produced by ``_discover_action_names`` for each action found
    in a module.
    """

    # Name the action is registered under (taken from the callable's
    # ``__rappel_action_name__`` attribute or from a registry entry).
    action_name: str
    # Module the action is attributed to.
    module_name: Optional[str]
    # Call signature as reported by ``inspect.signature``.
    signature: inspect.Signature
271
+
272
+
273
@dataclass
class FunctionParameter:
    """A function parameter with optional default value."""

    # Parameter name as it appears in the function signature.
    name: str
    # True when the signature declares a default for this parameter.
    has_default: bool
    # The actual Python value if has_default is True; None otherwise, so
    # has_default must be consulted to tell "no default" from a default of None.
    default_value: Any = None
280
+
281
+
282
@dataclass
class FunctionSignatureInfo:
    """Signature info for a workflow function, used to fill in default arguments."""

    # Parameters in declaration order; build_workflow_ir omits ``self`` when
    # it populates this list.
    parameters: List[FunctionParameter]
287
+
288
+
289
@dataclass
class ModuleContext:
    """Cached IRBuilder context derived from a module.

    build_workflow_ir computes one instance per module name and reuses it for
    every function that lives in that module.
    """

    # Actions visible in the module, keyed by the name used to refer to them.
    action_defs: Dict[str, ActionDefinition]
    # ``from X import Y`` names found in the module source, keyed by local name.
    imported_names: Dict[str, "ImportedName"]
    # Names of coroutine functions defined in the module itself.
    module_functions: Set[str]
    # Simple Pydantic models / dataclasses reachable as module attributes,
    # keyed by attribute name.
    model_defs: Dict[str, "ModelDefinition"]
297
+
298
+
299
@dataclass
class TransformContext:
    """Mutable state shared by the builders that convert one workflow.

    The two container fields default to ``None`` in the signature and are
    replaced with fresh empty containers after construction, so instances
    never share a list or dict.
    """

    # Number of implicit function names handed out so far.
    implicit_fn_counter: int = 0
    # Functions synthesised while transforming the workflow.
    implicit_functions: List[ir.FunctionDef] = None  # type: ignore
    # Signatures of workflow methods, keyed by function name; consulted to
    # fill in default arguments.
    function_signatures: Dict[str, FunctionSignatureInfo] = None  # type: ignore

    def __post_init__(self) -> None:
        """Swap the ``None`` placeholders for per-instance empty containers."""
        self.implicit_functions = (
            [] if self.implicit_functions is None else self.implicit_functions
        )
        self.function_signatures = (
            {} if self.function_signatures is None else self.function_signatures
        )

    def next_implicit_fn_name(self, prefix: str = "implicit") -> str:
        """Return a new name of the form ``__<prefix>_<n>__`` and advance the counter."""
        sequence_number = self.implicit_fn_counter + 1
        self.implicit_fn_counter = sequence_number
        return "__{}_{}__".format(prefix, sequence_number)
320
+
321
+
322
def build_workflow_ir(workflow_cls: type["Workflow"]) -> ir.Program:
    """Build an IR Program from a workflow class.

    The workflow's ``run`` implementation becomes the IR function ``main``.
    Every ``self.<method>(...)`` reachable from it (transitively) is converted
    into an additional IR function under its own name.

    Args:
        workflow_cls: The workflow class to convert.

    Returns:
        An IR Program proto message. Implicit functions generated during the
        transformation come first, followed by ``main`` and the helper methods.

    Raises:
        ValueError: If the run implementation or a function's module cannot
            be located.
        UnsupportedPatternError: If a function uses an unsupported pattern;
            the error is re-raised with its location mapped to the source file.
    """
    global _CURRENT_ACTION_NAMES

    original_run = getattr(workflow_cls, "__workflow_run_impl__", None)
    if original_run is None:
        original_run = workflow_cls.__dict__.get("run")
    if original_run is None:
        raise ValueError(f"workflow {workflow_cls!r} missing run implementation")

    module = inspect.getmodule(original_run)
    if module is None:
        raise ValueError(f"unable to locate module for workflow {workflow_cls!r}")

    # Per-module discovery results, computed lazily and reused for every
    # function that lives in the same module.
    module_contexts: Dict[str, ModuleContext] = {}

    def get_module_context(target_module: Any) -> ModuleContext:
        module_name = target_module.__name__
        if module_name not in module_contexts:
            module_contexts[module_name] = ModuleContext(
                action_defs=_discover_action_names(target_module),
                imported_names=_discover_module_imports(target_module),
                module_functions=_discover_module_functions(target_module),
                model_defs=_discover_model_definitions(target_module),
            )
        return module_contexts[module_name]

    # Build the IR with transformation context
    ctx = TransformContext()
    program = ir.Program()
    function_defs: Dict[str, ir.FunctionDef] = {}

    # Extract instance attributes from __init__ for policy resolution
    instance_attrs = _extract_instance_attrs(workflow_cls)

    def parse_function(fn: Any) -> tuple[ast.AST, Optional[str], int]:
        # The source is dedented so a method body parses as a top-level
        # definition; start_line is kept to map positions back to the file.
        source_lines, start_line = inspect.getsourcelines(fn)
        function_source = textwrap.dedent("".join(source_lines))
        filename = inspect.getsourcefile(fn)
        if filename is None:
            filename = inspect.getfile(fn)
        return ast.parse(function_source, filename=filename or "<unknown>"), filename, start_line

    def _with_source_location(
        err: UnsupportedPatternError,
        filename: Optional[str],
        start_line: int,
    ) -> UnsupportedPatternError:
        # Positions on ``err`` are relative to the dedented function source;
        # shift the line by the function's first line in the file and bump
        # the column by one.
        line = err.line
        col = err.col
        if line is not None:
            line = start_line + line - 1
        if col is not None:
            col = col + 1
        return UnsupportedPatternError(
            err.message,
            err.recommendation,
            line=line,
            col=col,
            filename=filename,
        )

    def add_function_def(
        fn: Any,
        fn_tree: ast.AST,
        filename: Optional[str],
        start_line: int,
        override_name: Optional[str] = None,
    ) -> None:
        global _CURRENT_ACTION_NAMES
        fn_module = inspect.getmodule(fn)
        if fn_module is None:
            raise ValueError(f"unable to locate module for function {fn!r}")

        ctx_data = get_module_context(fn_module)
        _CURRENT_ACTION_NAMES = set(ctx_data.action_defs.keys())
        builder = IRBuilder(
            ctx_data.action_defs,
            ctx,
            ctx_data.imported_names,
            ctx_data.module_functions,
            ctx_data.model_defs,
            fn_module.__dict__,
            instance_attrs,
        )
        try:
            builder.visit(fn_tree)
        except UnsupportedPatternError as err:
            raise _with_source_location(err, filename, start_line) from err
        if builder.function_def:
            if override_name:
                builder.function_def.name = override_name
            function_defs[builder.function_def.name] = builder.function_def

    # Discover all reachable helper methods first so we can pre-collect their signatures.
    # This is needed because when we process a function that calls another function,
    # we need to know the callee's signature to fill in default arguments.
    run_tree, run_filename, run_start_line = parse_function(original_run)

    # Collect all reachable methods and their trees
    methods_to_process: List[tuple[Any, ast.AST, Optional[str], int, Optional[str]]] = [
        (original_run, run_tree, run_filename, run_start_line, "main")
    ]
    pending = list(_collect_self_method_calls(run_tree))
    visited: Set[str] = set()
    skip_methods = {"run_action"}

    while pending:
        method_name = pending.pop()
        if method_name in visited or method_name == "run" or method_name in skip_methods:
            continue
        visited.add(method_name)

        method = _find_workflow_method(workflow_cls, method_name)
        if method is None:
            continue

        method_tree, method_filename, method_start_line = parse_function(method)
        methods_to_process.append((method, method_tree, method_filename, method_start_line, None))
        pending.extend(_collect_self_method_calls(method_tree))

    # Pre-collect signatures for all methods before processing IR.
    # This ensures that when we encounter a function call, we can fill in default args.
    for fn, _fn_tree, _filename, _start_line, override_name in methods_to_process:
        fn_name = override_name if override_name else fn.__name__
        sig = inspect.signature(fn)
        params: List[FunctionParameter] = []
        for param_name, param in sig.parameters.items():
            if param_name == "self":
                continue
            has_default = param.default is not inspect.Parameter.empty
            default_value = param.default if has_default else None
            params.append(
                FunctionParameter(
                    name=param_name,
                    has_default=has_default,
                    default_value=default_value,
                )
            )
        ctx.function_signatures[fn_name] = FunctionSignatureInfo(parameters=params)

    # Now process all functions with signatures already available.
    # add_function_def mutates the module-level _CURRENT_ACTION_NAMES; reset it
    # in ``finally`` so a failed build does not leave stale names behind.
    try:
        for fn, fn_tree, filename, start_line, override_name in methods_to_process:
            add_function_def(fn, fn_tree, filename, start_line, override_name)
    finally:
        _CURRENT_ACTION_NAMES = set()

    # Add implicit functions first (they may be called by the main function)
    for implicit_fn in ctx.implicit_functions:
        program.functions.append(implicit_fn)

    # Add all function definitions (run + reachable helper methods)
    for fn_def in function_defs.values():
        program.functions.append(fn_def)

    return program
484
+
485
+
486
+ def _collect_self_method_calls(tree: ast.AST) -> Set[str]:
487
+ """Collect self.method(...) call names from a parsed function AST."""
488
+ calls: Set[str] = set()
489
+ for node in ast.walk(tree):
490
+ if not isinstance(node, ast.Call):
491
+ continue
492
+ func = node.func
493
+ if isinstance(func, ast.Attribute) and isinstance(func.value, ast.Name):
494
+ if func.value.id == "self":
495
+ calls.add(func.attr)
496
+ return calls
497
+
498
+
499
+ def _find_workflow_method(workflow_cls: type["Workflow"], name: str) -> Optional[Any]:
500
+ """Find a workflow method by name across the class MRO."""
501
+ for base in workflow_cls.__mro__:
502
+ if name not in base.__dict__:
503
+ continue
504
+ value = base.__dict__[name]
505
+ if isinstance(value, staticmethod) or isinstance(value, classmethod):
506
+ return value.__func__
507
+ if inspect.isfunction(value):
508
+ return value
509
+ return None
510
+
511
+
512
def _extract_instance_attrs(workflow_cls: type["Workflow"]) -> Dict[str, ast.expr]:
    """Collect ``self.<attr> = <value>`` assignments from the workflow's ``__init__``.

    The source of ``__init__`` is parsed and every single-target assignment
    to an attribute of ``self`` is recorded, for example:
        self.retry_policy = RetryPolicy(attempts=3)
        self.timeout = 30

    Returns:
        A dict from attribute name to the AST node of the assigned value (the
        last assignment wins). Empty when there is no ``__init__`` function
        or its source cannot be read or parsed.
    """
    init_method = _find_workflow_method(workflow_cls, "__init__")
    if init_method is None:
        return {}

    try:
        init_lines = inspect.getsourcelines(init_method)[0]
        init_tree = ast.parse(textwrap.dedent("".join(init_lines)))
    except (OSError, TypeError, SyntaxError):
        return {}

    def targets_self_attribute(target: ast.expr) -> bool:
        # True for ``self.<something>`` where ``self`` is a bare name.
        return (
            isinstance(target, ast.Attribute)
            and isinstance(target.value, ast.Name)
            and target.value.id == "self"
        )

    found: Dict[str, ast.expr] = {}
    assignments = (n for n in ast.walk(init_tree) if isinstance(n, ast.Assign))
    for assignment in assignments:
        # Chained assignments (a = b = value) have several targets; skip them.
        if len(assignment.targets) != 1:
            continue
        (target,) = assignment.targets
        if targets_self_attribute(target):
            found[target.attr] = assignment.value
    return found
551
+
552
+
553
def _discover_action_names(module: Any) -> Dict[str, ActionDefinition]:
    """Discover all @action decorated functions in a module.

    Two sources are combined:

    1. Module attributes that are callable and carry a truthy
       ``__rappel_action_name__``; they are keyed by the attribute name.
    2. Registry entries whose ``module`` equals ``module.__name__``; they are
       keyed by the function's ``__name__`` and never override an entry
       found in step 1.

    Args:
        module: The module object to inspect.

    Returns:
        A dict mapping the name used in the module to its ActionDefinition.
    """
    names: Dict[str, ActionDefinition] = {}
    for attr_name in dir(module):
        # dir() can list names whose lookup fails; skip them, as the other
        # discovery helpers in this file do.
        try:
            attr = getattr(module, attr_name)
        except AttributeError:
            continue
        action_name = getattr(attr, "__rappel_action_name__", None)
        action_module = getattr(attr, "__rappel_action_module__", None)
        if callable(attr) and action_name:
            signature = inspect.signature(attr)
            names[attr_name] = ActionDefinition(
                action_name=action_name,
                module_name=action_module or module.__name__,
                signature=signature,
            )
    for entry in registry.entries():
        if entry.module != module.__name__:
            continue
        func_name = entry.func.__name__
        if func_name in names:
            continue
        signature = inspect.signature(entry.func)
        names[func_name] = ActionDefinition(
            action_name=entry.name,
            module_name=entry.module,
            signature=signature,
        )
    return names
580
+
581
+
582
+ def _discover_module_functions(module: Any) -> Set[str]:
583
+ """Discover all async function names defined in a module.
584
+
585
+ This is used to detect when users await functions in the same module
586
+ that are NOT decorated with @action.
587
+ """
588
+ function_names: Set[str] = set()
589
+ for attr_name in dir(module):
590
+ try:
591
+ attr = getattr(module, attr_name)
592
+ except AttributeError:
593
+ continue
594
+
595
+ # Only include functions defined in THIS module (not imported)
596
+ if not callable(attr):
597
+ continue
598
+ if not inspect.iscoroutinefunction(attr):
599
+ continue
600
+
601
+ # Check if the function is defined in this module
602
+ func_module = getattr(attr, "__module__", None)
603
+ if func_module == module.__name__:
604
+ function_names.add(attr_name)
605
+
606
+ return function_names
607
+
608
+
609
@dataclass
class ImportedName:
    """Tracks an imported name and its source module.

    Built by ``_discover_module_imports`` from ``from X import Y [as Z]``
    statements.
    """

    local_name: str  # Name used in code (e.g., "sleep"); the alias when "as" is used
    module: str  # Source module (e.g., "asyncio")
    original_name: str  # Original name in source module (e.g., "sleep")
616
+
617
+
618
@dataclass
class ModelFieldDefinition:
    """Definition of a field in a Pydantic model or dataclass."""

    # Field name as declared on the model.
    name: str
    # True when the field declares a plain default value; fields that only
    # have a default factory are recorded with has_default=False.
    has_default: bool
    default_value: Any = None  # Only set if has_default is True
625
+
626
+
627
@dataclass
class ModelDefinition:
    """Definition of a Pydantic model or dataclass that can be used in workflows.

    These are data classes that can be instantiated in workflow code and will
    be converted to dictionary expressions in the IR.
    """

    # Name the class is bound to in the module where it was discovered
    # (``_discover_model_definitions`` stores the module attribute name here).
    class_name: str
    # Field definitions keyed by field name.
    fields: Dict[str, ModelFieldDefinition]
    is_pydantic: bool  # True for Pydantic models, False for dataclasses
638
+
639
+
640
+ def _is_simple_pydantic_model(cls: type) -> bool:
641
+ """Check if a class is a simple Pydantic model without custom validators.
642
+
643
+ A simple Pydantic model:
644
+ - Inherits from pydantic.BaseModel
645
+ - Has no field_validator or model_validator decorators
646
+ - Has no custom __init__ method
647
+
648
+ Returns False if pydantic is not installed or cls is not a Pydantic model.
649
+ """
650
+ try:
651
+ from pydantic import BaseModel
652
+ except ImportError:
653
+ return False
654
+
655
+ if not isinstance(cls, type) or not issubclass(cls, BaseModel):
656
+ return False
657
+
658
+ # Check for validators - Pydantic v2 uses __pydantic_decorators__
659
+ decorators = getattr(cls, "__pydantic_decorators__", None)
660
+ if decorators is not None:
661
+ # Check for field validators
662
+ if hasattr(decorators, "field_validators") and decorators.field_validators:
663
+ return False
664
+ # Check for model validators
665
+ if hasattr(decorators, "model_validators") and decorators.model_validators:
666
+ return False
667
+
668
+ # Check for custom __init__ (not the one from BaseModel)
669
+ if "__init__" in cls.__dict__:
670
+ return False
671
+
672
+ return True
673
+
674
+
675
+ def _is_simple_dataclass(cls: type) -> bool:
676
+ """Check if a class is a simple dataclass without custom logic.
677
+
678
+ A simple dataclass:
679
+ - Is decorated with @dataclass
680
+ - Has no custom __init__ method (uses the generated one)
681
+ - Has no __post_init__ method
682
+
683
+ Returns False if cls is not a dataclass.
684
+ """
685
+ import dataclasses
686
+
687
+ if not dataclasses.is_dataclass(cls):
688
+ return False
689
+
690
+ # Check for __post_init__ which could have custom logic
691
+ if hasattr(cls, "__post_init__") and "__post_init__" in cls.__dict__:
692
+ return False
693
+
694
+ # Dataclasses generate __init__, so check if there's a custom one
695
+ # that overrides it (unlikely but possible)
696
+ # The dataclass decorator sets __init__ so we can't easily detect override
697
+ # We'll trust that dataclasses without __post_init__ are simple
698
+
699
+ return True
700
+
701
+
702
+ def _get_pydantic_model_fields(cls: type) -> Dict[str, ModelFieldDefinition]:
703
+ """Extract field definitions from a Pydantic model."""
704
+ try:
705
+ from pydantic import BaseModel
706
+ from pydantic.fields import FieldInfo
707
+ except ImportError:
708
+ return {}
709
+
710
+ if not issubclass(cls, BaseModel):
711
+ return {}
712
+
713
+ fields: Dict[str, ModelFieldDefinition] = {}
714
+
715
+ # Pydantic v2 uses model_fields
716
+ model_fields = getattr(cls, "model_fields", {})
717
+ for field_name, field_info in model_fields.items():
718
+ has_default = False
719
+ default_value = None
720
+
721
+ if isinstance(field_info, FieldInfo):
722
+ # Check if field has a default value
723
+ # PydanticUndefined means no default
724
+ from pydantic_core import PydanticUndefined
725
+
726
+ if field_info.default is not PydanticUndefined:
727
+ has_default = True
728
+ default_value = field_info.default
729
+ elif field_info.default_factory is not None:
730
+ # We can't serialize factory functions, so treat as no default
731
+ has_default = False
732
+
733
+ fields[field_name] = ModelFieldDefinition(
734
+ name=field_name,
735
+ has_default=has_default,
736
+ default_value=default_value,
737
+ )
738
+
739
+ return fields
740
+
741
+
742
+ def _get_dataclass_fields(cls: type) -> Dict[str, ModelFieldDefinition]:
743
+ """Extract field definitions from a dataclass."""
744
+ import dataclasses
745
+
746
+ if not dataclasses.is_dataclass(cls):
747
+ return {}
748
+
749
+ fields: Dict[str, ModelFieldDefinition] = {}
750
+ for field in dataclasses.fields(cls):
751
+ has_default = False
752
+ default_value = None
753
+
754
+ if field.default is not dataclasses.MISSING:
755
+ has_default = True
756
+ default_value = field.default
757
+ elif field.default_factory is not dataclasses.MISSING:
758
+ # We can't serialize factory functions, so treat as no default
759
+ has_default = False
760
+
761
+ fields[field.name] = ModelFieldDefinition(
762
+ name=field.name,
763
+ has_default=has_default,
764
+ default_value=default_value,
765
+ )
766
+
767
+ return fields
768
+
769
+
770
def _discover_model_definitions(module: Any) -> Dict[str, ModelDefinition]:
    """Find the Pydantic models and dataclasses usable from workflow code in ``module``.

    Every class reachable as a module attribute is considered, whether it is
    defined in the module or imported into it. Only "simple" models (no
    custom validators or __post_init__) are kept; the Pydantic check takes
    precedence over the dataclass check.

    Returns:
        A dict from module attribute name to its ModelDefinition.
    """
    discovered: Dict[str, ModelDefinition] = {}

    for attr_name in dir(module):
        try:
            candidate = getattr(module, attr_name)
        except AttributeError:
            continue

        if not isinstance(candidate, type):
            continue

        if _is_simple_pydantic_model(candidate):
            field_defs = _get_pydantic_model_fields(candidate)
            from_pydantic = True
        elif _is_simple_dataclass(candidate):
            field_defs = _get_dataclass_fields(candidate)
            from_pydantic = False
        else:
            continue

        discovered[attr_name] = ModelDefinition(
            class_name=attr_name,
            fields=field_defs,
            is_pydantic=from_pydantic,
        )

    return discovered
804
+
805
+
806
+ def _discover_module_imports(module: Any) -> Dict[str, ImportedName]:
807
+ """Discover imports in a module by parsing its source.
808
+
809
+ Tracks imports like:
810
+ - from asyncio import sleep -> {"sleep": ImportedName("sleep", "asyncio", "sleep")}
811
+ - from asyncio import sleep as s -> {"s": ImportedName("s", "asyncio", "sleep")}
812
+ """
813
+ imported: Dict[str, ImportedName] = {}
814
+
815
+ try:
816
+ source = inspect.getsource(module)
817
+ tree = ast.parse(source)
818
+ except (OSError, TypeError):
819
+ # Can't get source (e.g., built-in module)
820
+ return imported
821
+
822
+ for node in ast.walk(tree):
823
+ if isinstance(node, ast.ImportFrom) and node.module:
824
+ for alias in node.names:
825
+ local_name = alias.asname if alias.asname else alias.name
826
+ imported[local_name] = ImportedName(
827
+ local_name=local_name,
828
+ module=node.module,
829
+ original_name=alias.name,
830
+ )
831
+
832
+ return imported
833
+
834
+
835
+ class IRBuilder(ast.NodeVisitor):
836
+ """Builds IR from Python AST with deep transformations."""
837
+
838
+ def __init__(
839
+ self,
840
+ action_defs: Dict[str, ActionDefinition],
841
+ ctx: TransformContext,
842
+ imported_names: Optional[Dict[str, ImportedName]] = None,
843
+ module_functions: Optional[Set[str]] = None,
844
+ model_defs: Optional[Dict[str, ModelDefinition]] = None,
845
+ module_globals: Optional[Mapping[str, Any]] = None,
846
+ instance_attrs: Optional[Dict[str, ast.expr]] = None,
847
+ ):
848
+ self._action_defs = action_defs
849
+ self._ctx = ctx
850
+ self._imported_names = imported_names or {}
851
+ self._module_functions = module_functions or set()
852
+ self._model_defs = model_defs or {}
853
+ self._module_globals = module_globals or {}
854
+ self._instance_attrs = instance_attrs or {}
855
+ self.function_def: Optional[ir.FunctionDef] = None
856
+ self._statements: List[ir.Statement] = []
857
+
858
def visit_FunctionDef(self, node: ast.FunctionDef) -> Any:
    """Convert a function definition (the workflow's run method) into ``self.function_def``.

    The IR function takes its name, inputs and span from ``node``; its
    outputs list is left empty. Each body statement is converted in order
    and the resulting IR statements form the function body.
    """
    self.function_def = ir.FunctionDef(
        name=node.name,
        io=ir.IoDecl(inputs=self._collect_function_inputs(node), outputs=[]),
        span=_make_span(node),
    )

    # One Python statement may expand to several IR statements, so the
    # converter returns a list for each of them.
    self._statements = []
    for py_stmt in node.body:
        self._statements.extend(self._visit_statement(py_stmt))

    body_block = ir.Block(statements=self._statements)
    self.function_def.body.CopyFrom(body_block)
877
+
878
def visit_AsyncFunctionDef(self, node: ast.AsyncFunctionDef) -> Any:
    """Visit an async function definition (the workflow's run method).

    Async functions are converted exactly like sync ones for IR building,
    so this delegates to ``visit_FunctionDef`` instead of repeating its
    body; only ``name``, ``args`` and ``body`` of the node are needed, and
    both node types provide them.
    """
    return self.visit_FunctionDef(node)  # type: ignore[arg-type]
895
+
896
+ def _visit_statement(self, node: ast.stmt) -> List[ir.Statement]:
897
+ """Convert a Python statement to IR Statement(s).
898
+
899
+ Returns a list because some transformations (like try block hoisting)
900
+ may expand a single Python statement into multiple IR statements.
901
+
902
+ Raises UnsupportedPatternError for unsupported statement types.
903
+ """
904
+ if isinstance(node, ast.Assign):
905
+ dict_expanded = self._expand_dict_comprehension_assignment(node)
906
+ if dict_expanded is not None:
907
+ return dict_expanded
908
+ expanded = self._expand_list_comprehension_assignment(node)
909
+ if expanded is not None:
910
+ return expanded
911
+ result = self._visit_assign(node)
912
+ return [result] if result else []
913
+ elif isinstance(node, ast.AnnAssign):
914
+ result = self._visit_ann_assign(node)
915
+ return [result] if result else []
916
+ elif isinstance(node, ast.Expr):
917
+ result = self._visit_expr_stmt(node)
918
+ return [result] if result else []
919
+ elif isinstance(node, ast.For):
920
+ return self._visit_for(node)
921
+ elif isinstance(node, ast.If):
922
+ return self._visit_if(node)
923
+ elif isinstance(node, ast.Try):
924
+ return self._visit_try(node)
925
+ elif isinstance(node, ast.Return):
926
+ return self._visit_return(node)
927
+ elif isinstance(node, ast.AugAssign):
928
+ return self._visit_aug_assign(node)
929
+ elif isinstance(node, ast.Pass):
930
+ # Pass statements are fine, they just don't produce IR
931
+ return []
932
+ elif isinstance(node, ast.Break):
933
+ return self._visit_break(node)
934
+ elif isinstance(node, ast.Continue):
935
+ return self._visit_continue(node)
936
+
937
+ # Check for unsupported statement types - this MUST raise for any
938
+ # unhandled statement to avoid silently dropping code
939
+ self._check_unsupported_statement(node)
940
+
941
def _check_unsupported_statement(self, node: ast.stmt) -> NoReturn:
    """Raise UnsupportedPatternError describing why ``node`` is rejected.

    This method never returns. Known unsupported statement types get a
    specific message and recommendation; every other type falls through
    to a generic "unhandled" error, so no statement is dropped silently.
    """
    line = getattr(node, "lineno", None)
    col = getattr(node, "col_offset", None)

    # (node type(s), error message, key into RECOMMENDATIONS)
    known_unsupported = [
        (ast.While, "While loops are not supported", "while_loop"),
        (
            (ast.With, ast.AsyncWith),
            "Context managers (with statements) are not supported",
            "with_statement",
        ),
        (ast.Raise, "The 'raise' statement is not supported", "raise_statement"),
        (ast.Assert, "Assert statements are not supported", "assert_statement"),
        (ast.Delete, "The 'del' statement is not supported", "delete"),
        (ast.Global, "Global statements are not supported", "global_statement"),
        (ast.Nonlocal, "Nonlocal statements are not supported", "nonlocal_statement"),
        (
            (ast.Import, ast.ImportFrom),
            "Import statements inside workflow run() are not supported",
            "import_statement",
        ),
        (
            ast.ClassDef,
            "Class definitions inside workflow run() are not supported",
            "class_def",
        ),
        (
            (ast.FunctionDef, ast.AsyncFunctionDef),
            "Nested function definitions are not supported",
            "function_def",
        ),
    ]
    # ast.Match only exists on interpreters that support `match`.
    if hasattr(ast, "Match"):
        known_unsupported.append((ast.Match, "Match statements are not supported", "match"))

    for node_types, message, recommendation_key in known_unsupported:
        if isinstance(node, node_types):
            raise UnsupportedPatternError(
                message,
                RECOMMENDATIONS[recommendation_key],
                line=line,
                col=col,
            )

    # Catch-all for statement types with no dedicated entry above.
    raise UnsupportedPatternError(
        f"Unhandled statement type: {type(node).__name__}",
        RECOMMENDATIONS["unsupported_statement"],
        line=line,
        col=col,
    )
1039
+
1040
+ def _expand_list_comprehension_assignment(
1041
+ self, node: ast.Assign
1042
+ ) -> Optional[List[ir.Statement]]:
1043
+ """Expand a list comprehension assignment into loop-based statements.
1044
+
1045
+ Example:
1046
+ active_users = [user for user in users if user.active]
1047
+
1048
+ Becomes:
1049
+ active_users = []
1050
+ for user in users:
1051
+ if user.active:
1052
+ active_users = active_users + [user]
1053
+ """
1054
+ if not isinstance(node.value, ast.ListComp):
1055
+ return None
1056
+
1057
+ if len(node.targets) != 1 or not isinstance(node.targets[0], ast.Name):
1058
+ line = getattr(node, "lineno", None)
1059
+ col = getattr(node, "col_offset", None)
1060
+ raise UnsupportedPatternError(
1061
+ "List comprehension assignments must target a single variable",
1062
+ "Assign the comprehension to a simple variable like `results = [x for x in items]`",
1063
+ line=line,
1064
+ col=col,
1065
+ )
1066
+
1067
+ listcomp = node.value
1068
+ if len(listcomp.generators) != 1:
1069
+ line = getattr(listcomp, "lineno", None)
1070
+ col = getattr(listcomp, "col_offset", None)
1071
+ raise UnsupportedPatternError(
1072
+ "List comprehensions with multiple generators are not supported",
1073
+ "Use nested for loops instead of combining multiple generators in one comprehension",
1074
+ line=line,
1075
+ col=col,
1076
+ )
1077
+
1078
+ gen = listcomp.generators[0]
1079
+ if gen.is_async:
1080
+ line = getattr(listcomp, "lineno", None)
1081
+ col = getattr(listcomp, "col_offset", None)
1082
+ raise UnsupportedPatternError(
1083
+ "Async list comprehensions are not supported",
1084
+ "Rewrite using an explicit async for loop",
1085
+ line=line,
1086
+ col=col,
1087
+ )
1088
+
1089
+ target_name = node.targets[0].id
1090
+
1091
+ # Initialize the accumulator list: active_users = []
1092
+ init_assign_ast = ast.Assign(
1093
+ targets=[ast.Name(id=target_name, ctx=ast.Store())],
1094
+ value=ast.List(elts=[], ctx=ast.Load()),
1095
+ type_comment=None,
1096
+ )
1097
+ ast.copy_location(init_assign_ast, node)
1098
+ ast.fix_missing_locations(init_assign_ast)
1099
+
1100
+ def _make_append_assignment(value_expr: ast.expr) -> ast.Assign:
1101
+ append_assign = ast.Assign(
1102
+ targets=[ast.Name(id=target_name, ctx=ast.Store())],
1103
+ value=ast.BinOp(
1104
+ left=ast.Name(id=target_name, ctx=ast.Load()),
1105
+ op=ast.Add(),
1106
+ right=ast.List(elts=[copy.deepcopy(value_expr)], ctx=ast.Load()),
1107
+ ),
1108
+ type_comment=None,
1109
+ )
1110
+ ast.copy_location(append_assign, node.value)
1111
+ ast.fix_missing_locations(append_assign)
1112
+ return append_assign
1113
+
1114
+ append_statements: List[ast.stmt] = []
1115
+ if isinstance(listcomp.elt, ast.IfExp):
1116
+ then_assign = _make_append_assignment(listcomp.elt.body)
1117
+ else_assign = _make_append_assignment(listcomp.elt.orelse)
1118
+ branch_if = ast.If(
1119
+ test=copy.deepcopy(listcomp.elt.test),
1120
+ body=[then_assign],
1121
+ orelse=[else_assign],
1122
+ )
1123
+ ast.copy_location(branch_if, listcomp.elt)
1124
+ ast.fix_missing_locations(branch_if)
1125
+ append_statements.append(branch_if)
1126
+ else:
1127
+ append_statements.append(_make_append_assignment(listcomp.elt))
1128
+
1129
+ loop_body: List[ast.stmt] = append_statements
1130
+ if gen.ifs:
1131
+ condition: ast.expr
1132
+ if len(gen.ifs) == 1:
1133
+ condition = copy.deepcopy(gen.ifs[0])
1134
+ else:
1135
+ condition = ast.BoolOp(op=ast.And(), values=[copy.deepcopy(iff) for iff in gen.ifs])
1136
+ ast.copy_location(condition, gen.ifs[0])
1137
+ if_stmt = ast.If(test=condition, body=append_statements, orelse=[])
1138
+ ast.copy_location(if_stmt, gen.ifs[0])
1139
+ ast.fix_missing_locations(if_stmt)
1140
+ loop_body = [if_stmt]
1141
+
1142
+ loop_ast = ast.For(
1143
+ target=copy.deepcopy(gen.target),
1144
+ iter=copy.deepcopy(gen.iter),
1145
+ body=loop_body,
1146
+ orelse=[],
1147
+ type_comment=None,
1148
+ )
1149
+ ast.copy_location(loop_ast, node)
1150
+ ast.fix_missing_locations(loop_ast)
1151
+
1152
+ statements: List[ir.Statement] = []
1153
+ init_stmt = self._visit_assign(init_assign_ast)
1154
+ if init_stmt:
1155
+ statements.append(init_stmt)
1156
+ statements.extend(self._visit_for(loop_ast))
1157
+
1158
+ return statements
1159
+
1160
+ def _expand_dict_comprehension_assignment(
1161
+ self, node: ast.Assign
1162
+ ) -> Optional[List[ir.Statement]]:
1163
+ """Expand a dict comprehension assignment into loop-based statements."""
1164
+ if not isinstance(node.value, ast.DictComp):
1165
+ return None
1166
+
1167
+ if len(node.targets) != 1 or not isinstance(node.targets[0], ast.Name):
1168
+ line = getattr(node, "lineno", None)
1169
+ col = getattr(node, "col_offset", None)
1170
+ raise UnsupportedPatternError(
1171
+ "Dict comprehension assignments must target a single variable",
1172
+ "Assign the comprehension to a simple variable like `result = {k: v for k, v in pairs}`",
1173
+ line=line,
1174
+ col=col,
1175
+ )
1176
+
1177
+ dictcomp = node.value
1178
+ if len(dictcomp.generators) != 1:
1179
+ line = getattr(dictcomp, "lineno", None)
1180
+ col = getattr(dictcomp, "col_offset", None)
1181
+ raise UnsupportedPatternError(
1182
+ "Dict comprehensions with multiple generators are not supported",
1183
+ "Use nested for loops instead of combining multiple generators in one comprehension",
1184
+ line=line,
1185
+ col=col,
1186
+ )
1187
+
1188
+ gen = dictcomp.generators[0]
1189
+ if gen.is_async:
1190
+ line = getattr(dictcomp, "lineno", None)
1191
+ col = getattr(dictcomp, "col_offset", None)
1192
+ raise UnsupportedPatternError(
1193
+ "Async dict comprehensions are not supported",
1194
+ "Rewrite using an explicit async for loop",
1195
+ line=line,
1196
+ col=col,
1197
+ )
1198
+
1199
+ target_name = node.targets[0].id
1200
+
1201
+ # Initialize accumulator: result = {}
1202
+ init_assign_ast = ast.Assign(
1203
+ targets=[ast.Name(id=target_name, ctx=ast.Store())],
1204
+ value=ast.Dict(keys=[], values=[]),
1205
+ type_comment=None,
1206
+ )
1207
+ ast.copy_location(init_assign_ast, node)
1208
+ ast.fix_missing_locations(init_assign_ast)
1209
+
1210
+ # result[key] = value
1211
+ subscript_target = ast.Subscript(
1212
+ value=ast.Name(id=target_name, ctx=ast.Load()),
1213
+ slice=copy.deepcopy(dictcomp.key),
1214
+ ctx=ast.Store(),
1215
+ )
1216
+ append_assign_ast = ast.Assign(
1217
+ targets=[subscript_target],
1218
+ value=copy.deepcopy(dictcomp.value),
1219
+ type_comment=None,
1220
+ )
1221
+ ast.copy_location(append_assign_ast, node.value)
1222
+ ast.fix_missing_locations(append_assign_ast)
1223
+
1224
+ loop_body: List[ast.stmt] = []
1225
+ if gen.ifs:
1226
+ condition: ast.expr
1227
+ if len(gen.ifs) == 1:
1228
+ condition = copy.deepcopy(gen.ifs[0])
1229
+ else:
1230
+ condition = ast.BoolOp(op=ast.And(), values=[copy.deepcopy(iff) for iff in gen.ifs])
1231
+ ast.copy_location(condition, gen.ifs[0])
1232
+ if_stmt = ast.If(test=condition, body=[append_assign_ast], orelse=[])
1233
+ ast.copy_location(if_stmt, gen.ifs[0])
1234
+ ast.fix_missing_locations(if_stmt)
1235
+ loop_body.append(if_stmt)
1236
+ else:
1237
+ loop_body.append(append_assign_ast)
1238
+
1239
+ loop_ast = ast.For(
1240
+ target=copy.deepcopy(gen.target),
1241
+ iter=copy.deepcopy(gen.iter),
1242
+ body=loop_body,
1243
+ orelse=[],
1244
+ type_comment=None,
1245
+ )
1246
+ ast.copy_location(loop_ast, node)
1247
+ ast.fix_missing_locations(loop_ast)
1248
+
1249
+ statements: List[ir.Statement] = []
1250
+ init_stmt = self._visit_assign(init_assign_ast)
1251
+ if init_stmt:
1252
+ statements.append(init_stmt)
1253
+ statements.extend(self._visit_for(loop_ast))
1254
+
1255
+ return statements
1256
+
1257
+ def _visit_ann_assign(self, node: ast.AnnAssign) -> Optional[ir.Statement]:
1258
+ """Convert annotated assignment to IR when a value is present."""
1259
+ if node.value is None:
1260
+ return None
1261
+
1262
+ assign = ast.Assign(targets=[node.target], value=node.value, type_comment=None)
1263
+ ast.copy_location(assign, node)
1264
+ ast.fix_missing_locations(assign)
1265
+ return self._visit_assign(assign)
1266
+
1267
def _visit_assign(self, node: ast.Assign) -> Optional[ir.Statement]:
    """Convert an assignment into an IR Assignment statement.

    Every supported right-hand side is wrapped in the same Assignment
    statement type, which gives uniform unpacking for:
    - Action calls: a, b = @get_pair()
    - Parallel blocks: a, b = parallel: @x() @y()
    - Regular expressions: a, b = some_list

    The right-hand side is classified in this order: model constructor,
    ``await asyncio.gather(...)``, action call, regular expression.
    Returns None when the regular-expression conversion yields nothing.

    Raises UnsupportedPatternError (from the helpers it calls) for:
    - Constructor calls: x = MyClass(...)
    - Non-action await: x = await some_func()
    """
    stmt = ir.Statement(span=_make_span(node))
    targets = self._get_assign_targets(node.targets)
    rhs = node.value

    def emit(value):
        # Attach `targets = value` to the statement and hand it back.
        stmt.assignment.CopyFrom(ir.Assignment(targets=targets, value=value))
        return stmt

    # Pydantic model / dataclass constructors are lowered to dict expressions.
    model_name = self._is_model_constructor(rhs)
    if model_name and isinstance(rhs, ast.Call):
        return emit(self._convert_model_constructor_to_dict(rhs, model_name))

    # Other constructor calls are rejected. This has to run after the
    # model check above, because model constructors are allowed.
    self._check_constructor_in_assignment(rhs)

    # await asyncio.gather(...) becomes a parallel or spread expression.
    if isinstance(rhs, ast.Await) and isinstance(rhs.value, ast.Call):
        awaited_call = rhs.value
        if self._is_asyncio_gather_call(awaited_call):
            gathered = self._convert_asyncio_gather(awaited_call)
            if gathered is not None:
                if isinstance(gathered, ir.ParallelExpr):
                    return emit(ir.Expr(parallel_expr=gathered, span=_make_span(node)))
                # Anything else is a SpreadExpr.
                return emit(ir.Expr(spread_expr=gathered, span=_make_span(node)))

    # Action calls are wrapped in an Assignment for uniform unpacking.
    action_call = self._extract_action_call(rhs)
    if action_call:
        return emit(ir.Expr(action_call=action_call, span=_make_span(node)))

    # Regular assignment (variables, literals, expressions).
    plain_value = self._expr_to_ir_with_model_coercion(rhs)
    return emit(plain_value) if plain_value else None
1327
+
1328
def _visit_expr_stmt(self, node: ast.Expr) -> Optional[ir.Statement]:
    """Convert an expression statement (side effect only) to IR.

    Cases are tried in this order:
    1. ``await asyncio.gather(...)``: a ParallelBlock statement, or an
       Assignment with no targets when the gather is a spread.
    2. An action call: an action_call statement.
    3. ``name.append(value)``: rewritten to ``name = name + [value]``.
    4. Any other expression: an ExprStmt.

    Returns None when the expression cannot be converted.
    """
    stmt = ir.Statement(span=_make_span(node))
    expr_node = node.value

    if isinstance(expr_node, ast.Await) and isinstance(expr_node.value, ast.Call):
        awaited_call = expr_node.value
        if self._is_asyncio_gather_call(awaited_call):
            gathered = self._convert_asyncio_gather(awaited_call)
            if gathered is not None:
                if not isinstance(gathered, ir.ParallelExpr):
                    # SpreadExpr as a side effect, e.g.
                    #   await asyncio.gather(*[action(x) for x in items])
                    # is wrapped in an assignment with no targets.
                    spread_value = ir.Expr(spread_expr=gathered, span=_make_span(node))
                    stmt.assignment.CopyFrom(ir.Assignment(targets=[], value=spread_value))
                    return stmt
                # Side effect only: emit a ParallelBlock statement.
                block = ir.ParallelBlock()
                block.calls.extend(gathered.calls)
                stmt.parallel_block.CopyFrom(block)
                return stmt

    # A bare action call whose result is discarded.
    action_call = self._extract_action_call(expr_node)
    if action_call:
        stmt.action_call.CopyFrom(action_call)
        return stmt

    # list.append(x) becomes list = list + [x]; making the mutation
    # explicit lets the data flow correctly through the DAG.
    if isinstance(expr_node, ast.Call):
        callee = expr_node.func
        is_simple_append = (
            isinstance(callee, ast.Attribute)
            and callee.attr == "append"
            and isinstance(callee.value, ast.Name)
            and len(expr_node.args) == 1
        )
        if is_simple_append:
            receiver = callee.value.id
            receiver_expr = ir.Expr(variable=ir.Variable(name=receiver), span=_make_span(node))
            item_expr = self._expr_to_ir_with_model_coercion(expr_node.args[0])
            if item_expr:
                # [value] as a list literal, then list + [value].
                singleton = ir.Expr(
                    list=ir.ListExpr(elements=[item_expr]), span=_make_span(node)
                )
                concatenated = ir.Expr(
                    binary_op=ir.BinaryOp(
                        op=ir.BinaryOperator.BINARY_OP_ADD,
                        left=receiver_expr,
                        right=singleton,
                    ),
                    span=_make_span(node),
                )
                stmt.assignment.CopyFrom(
                    ir.Assignment(targets=[receiver], value=concatenated)
                )
                return stmt

    # Any other expression is kept as a plain expression statement.
    fallback = self._expr_to_ir_with_model_coercion(expr_node)
    if fallback:
        stmt.expr_stmt.CopyFrom(ir.ExprStmt(expr=fallback))
        return stmt

    return None
1396
+
1397
def _visit_for(self, node: ast.For) -> List[ir.Statement]:
    """Convert a for loop to IR.

    The loop body is emitted as a full block so it can contain multiple
    statements/calls and early `return`.

    Loop targets must be a single name or a flat tuple/list of names.

    Returns:
        A one-element list holding the ForLoop statement, or an empty
        list when the iterable expression cannot be converted.

    Raises:
        UnsupportedPatternError: if the loop has an ``else`` clause, or if
            the target contains anything other than plain names. Both
            cases used to be dropped silently, which changed the meaning
            of the workflow.
    """
    if node.orelse:
        raise UnsupportedPatternError(
            "The 'else' clause on for loops is not supported",
            "Remove the else clause and track loop completion with an explicit flag variable",
            line=getattr(node, "lineno", None),
            col=getattr(node, "col_offset", None),
        )

    # Get loop variables. A flat tuple/list target contributes one variable
    # per element; any other target is checked as a single element.
    target = node.target
    if isinstance(target, (ast.Tuple, ast.List)):
        target_elts = list(target.elts)
    else:
        target_elts = [target]

    loop_vars: List[str] = []
    for elt in target_elts:
        if not isinstance(elt, ast.Name):
            raise UnsupportedPatternError(
                "For loop targets must be simple variable names",
                "Use `for item in items` or `for key, value in pairs` and unpack nested "
                "values inside the loop body",
                line=getattr(elt, "lineno", None),
                col=getattr(elt, "col_offset", None),
            )
        loop_vars.append(elt.id)

    # Get iterable
    iterable = self._expr_to_ir_with_model_coercion(node.iter)
    if not iterable:
        return []

    # Build body statements (recursively transforms nested structures)
    body_stmts: List[ir.Statement] = []
    for body_node in node.body:
        body_stmts.extend(self._visit_statement(body_node))

    stmt = ir.Statement(span=_make_span(node))
    for_loop = ir.ForLoop(
        loop_vars=loop_vars,
        iterable=iterable,
        block_body=ir.Block(statements=body_stmts, span=_make_span(node)),
    )
    stmt.for_loop.CopyFrom(for_loop)
    return [stmt]
1431
+
1432
+ def _detect_accumulator_targets(
1433
+ self, stmts: List[ir.Statement], in_scope_vars: set
1434
+ ) -> List[str]:
1435
+ """Detect out-of-scope variable modifications in for loop body.
1436
+
1437
+ Scans statements for patterns that modify variables defined outside the loop.
1438
+ Returns a list of accumulator variable names that should be set as targets.
1439
+
1440
+ Supported patterns:
1441
+ 1. List append: results.append(value) -> "results"
1442
+ 2. Dict subscript: result[key] = value -> "result"
1443
+ 3. List/set update methods: results.extend(...), results.add(...) -> "results"
1444
+
1445
+ Note: Patterns like `results = results + [x]` and `count = count + 1` create
1446
+ new assignments which are tracked via in_scope_vars and don't need special
1447
+ detection here - they're handled by the regular assignment target logic.
1448
+ """
1449
+ accumulators: List[str] = []
1450
+ seen: set = set()
1451
+
1452
+ for stmt in stmts:
1453
+ var_name = self._extract_accumulator_from_stmt(stmt, in_scope_vars)
1454
+ if var_name and var_name not in seen:
1455
+ accumulators.append(var_name)
1456
+ seen.add(var_name)
1457
+
1458
+ # Check conditionals for accumulator targets in branch bodies
1459
+ if stmt.HasField("conditional"):
1460
+ cond = stmt.conditional
1461
+ branch_blocks: list[ir.Block] = []
1462
+ if cond.HasField("if_branch") and cond.if_branch.HasField("block_body"):
1463
+ branch_blocks.append(cond.if_branch.block_body)
1464
+ for branch in cond.elif_branches:
1465
+ if branch.HasField("block_body"):
1466
+ branch_blocks.append(branch.block_body)
1467
+ if cond.HasField("else_branch") and cond.else_branch.HasField("block_body"):
1468
+ branch_blocks.append(cond.else_branch.block_body)
1469
+
1470
+ for block in branch_blocks:
1471
+ for var in self._detect_accumulator_targets(
1472
+ list(block.statements), in_scope_vars
1473
+ ):
1474
+ if var not in seen:
1475
+ accumulators.append(var)
1476
+ seen.add(var)
1477
+
1478
+ return accumulators
1479
+
1480
def _extract_accumulator_from_stmt(
    self, stmt: ir.Statement, in_scope_vars: set
) -> Optional[str]:
    """Return the outer variable this statement modifies, if any.

    Three patterns are checked in order:
    1. An expression statement calling a mutating method
       (``x.append(...)``, ``x.update(...)``, ...) on a name that is not
       in ``in_scope_vars``.
    2. An assignment whose target is a subscript (``x[key] = value``)
       with a base name that is not in ``in_scope_vars``.
    3. An assignment whose target also appears on its right-hand side
       (``x = x + [y]``). ``in_scope_vars`` is deliberately not consulted
       here: the assignment itself would have put the target in scope,
       yet it still needs the previous value from outside the loop body.

    Returns None when no pattern matches.
    """
    mutating_method_names = {
        "append",
        "extend",
        "add",
        "update",
        "insert",
        "pop",
        "remove",
        "clear",
    }

    # Pattern 1: method calls like list.append(), dict.update(), set.add()
    if stmt.HasField("expr_stmt"):
        expression = stmt.expr_stmt.expr
        if expression.HasField("function_call"):
            # Split "receiver.method" at the last dot.
            receiver, dot, method = expression.function_call.name.rpartition(".")
            if (
                dot
                and method in mutating_method_names
                and receiver
                and receiver not in in_scope_vars
            ):
                return receiver

    if stmt.HasField("assignment"):
        assignment = stmt.assignment

        # Pattern 2: subscript assignment like dict[key] = value
        for target in assignment.targets:
            if "[" not in target:
                continue
            base_name = target.partition("[")[0]
            if base_name and base_name not in in_scope_vars:
                return base_name

        # Pattern 3: self-referential assignment like x = x + [y]
        names_read = self._collect_variables_from_expr(assignment.value)
        for target in assignment.targets:
            if target in names_read:
                return target

    return None
1534
+
1535
def _collect_variables_from_expr(self, expr: ir.Expr) -> set:
    """Return the set of variable names referenced anywhere in ``expr``.

    The expression is walked recursively through binary/unary operands,
    list elements, dict keys and values, index base and index, the object
    of a dot access, and the keyword-argument values of function and
    action calls. Expression kinds not listed contribute no names.
    """
    # Leaf case: a direct variable reference.
    if expr.HasField("variable"):
        return {expr.variable.name}

    # Otherwise gather the immediate sub-expressions for this kind.
    if expr.HasField("binary_op"):
        children = [expr.binary_op.left, expr.binary_op.right]
    elif expr.HasField("unary_op"):
        children = [expr.unary_op.operand]
    elif expr.HasField("list"):
        children = list(expr.list.elements)
    elif expr.HasField("dict"):
        children = []
        for entry in expr.dict.entries:
            children.append(entry.key)
            children.append(entry.value)
    elif expr.HasField("index"):
        children = [expr.index.value, expr.index.index]
    elif expr.HasField("dot"):
        children = [expr.dot.object]
    elif expr.HasField("function_call"):
        children = [kwarg.value for kwarg in expr.function_call.kwargs]
    elif expr.HasField("action_call"):
        children = [kwarg.value for kwarg in expr.action_call.kwargs]
    else:
        children = []

    names: set = set()
    for child in children:
        names |= self._collect_variables_from_expr(child)
    return names
1566
+
1567
def _visit_if(self, node: ast.If) -> List[ir.Statement]:
    """Convert an if / elif / else statement to IR.

    A condition that is an awaited action call is hoisted into a
    temporary first, so that:
        if await some_action(...):
            ...
    becomes:
        __if_cond_n__ = await some_action(...)
        if __if_cond_n__:
            ...

    When only the first test needs hoisting (or none does), a single flat
    conditional with elif/else branches is produced. When any later test
    needs hoisting, the chain is rebuilt as nested if/else statements so
    each hoisted call only runs after the earlier tests have failed.

    Branches whose condition cannot be converted are skipped; in the flat
    form an unconvertible first condition yields an empty list.

    Raises UnsupportedPatternError when an action call is used in a
    condition without being directly awaited.
    """

    def lower_test(test):
        # Returns (statements to run before the test, condition expression).
        awaited_action = self._extract_action_call(test)
        if awaited_action is None:
            return ([], self._expr_to_ir_with_model_coercion(test))

        if not isinstance(test, ast.Await):
            raise UnsupportedPatternError(
                "Action calls inside boolean expressions are not supported in if conditions",
                "Assign the awaited action result to a variable, then use the variable in the if condition.",
                line=getattr(test, "lineno", None),
                col=getattr(test, "col_offset", None),
            )

        temp_name = self._ctx.next_implicit_fn_name(prefix="if_cond")
        hoisted = ir.Statement(span=_make_span(test))
        hoisted.assignment.CopyFrom(
            ir.Assignment(
                targets=[temp_name],
                value=ir.Expr(action_call=awaited_action, span=_make_span(test)),
            )
        )
        temp_ref = ir.Expr(variable=ir.Variable(name=temp_name), span=_make_span(test))
        return ([hoisted], temp_ref)

    def lower_body(nodes):
        return [lowered for child in nodes for lowered in self._visit_statement(child)]

    def block_at(statements, origin):
        return ir.Block(statements=statements, span=_make_span(origin))

    def if_branch_at(condition, statements, origin):
        return ir.IfBranch(
            condition=condition,
            block_body=block_at(statements, origin),
            span=_make_span(origin),
        )

    def else_branch_at(statements, origin):
        return ir.ElseBranch(
            block_body=block_at(statements, origin),
            span=_make_span(origin),
        )

    def wrap(conditional, origin):
        wrapped = ir.Statement(span=_make_span(origin))
        wrapped.conditional.CopyFrom(conditional)
        return wrapped

    # Walk the elif chain: an `elif` is an ast.If that is the sole
    # statement of the previous orelse.
    chain = [(node.test, node.body, node)]
    tail = node
    while tail.orelse and len(tail.orelse) == 1 and isinstance(tail.orelse[0], ast.If):
        tail = tail.orelse[0]
        chain.append((tail.test, tail.body, tail))
    trailing_else_nodes = tail.orelse

    # Lower every branch in source order: test first, then body.
    lowered_branches = []
    for test, body_nodes, origin in chain:
        prefix, condition = lower_test(test)
        lowered_branches.append((prefix, condition, lower_body(body_nodes), origin))

    else_stmts = lower_body(trailing_else_nodes) if trailing_else_nodes else []

    # A hoisted test on any non-first branch forces the nested form.
    needs_nesting = any(prefix for prefix, _, _, _ in lowered_branches[1:])

    if needs_nesting:
        # Build inside-out: each branch's else body holds the next
        # branch's hoisted prefix followed by its conditional.
        tail_stmts = else_stmts
        for prefix, condition, then_stmts, origin in reversed(lowered_branches):
            if condition is None:
                continue
            nested = ir.Conditional(if_branch=if_branch_at(condition, then_stmts, origin))
            if tail_stmts:
                nested.else_branch.CopyFrom(else_branch_at(tail_stmts, origin))
            tail_stmts = [*prefix, wrap(nested, origin)]
        return tail_stmts

    # Flat conditional with elif/else, plus an optional prefix for the
    # first test.
    first_prefix, first_condition, first_body, first_origin = lowered_branches[0]
    if first_condition is None:
        return []

    flat = ir.Conditional(if_branch=if_branch_at(first_condition, first_body, first_origin))

    for _, condition, body, origin in lowered_branches[1:]:
        if condition is None:
            continue
        flat.elif_branches.append(
            ir.ElifBranch(
                condition=condition,
                block_body=block_at(body, origin),
                span=_make_span(origin),
            )
        )

    if else_stmts:
        flat.else_branch.CopyFrom(else_branch_at(else_stmts, first_origin))

    return [*first_prefix, wrap(flat, first_origin)]
1704
+
1705
+ def _collect_assigned_vars(self, stmts: List[ir.Statement]) -> set:
1706
+ """Collect all variable names assigned in a list of statements."""
1707
+ assigned = set()
1708
+ for stmt in stmts:
1709
+ if stmt.HasField("assignment"):
1710
+ assigned.update(stmt.assignment.targets)
1711
+ return assigned
1712
+
1713
+ def _collect_assigned_vars_in_order(self, stmts: List[ir.Statement]) -> list[str]:
1714
+ """Collect assigned variable names in statement order (deduplicated)."""
1715
+ assigned: list[str] = []
1716
+ seen: set[str] = set()
1717
+
1718
+ for stmt in stmts:
1719
+ if stmt.HasField("assignment"):
1720
+ for target in stmt.assignment.targets:
1721
+ if target not in seen:
1722
+ seen.add(target)
1723
+ assigned.append(target)
1724
+
1725
+ if stmt.HasField("conditional"):
1726
+ cond = stmt.conditional
1727
+ if cond.HasField("if_branch") and cond.if_branch.HasField("block_body"):
1728
+ for target in self._collect_assigned_vars_in_order(
1729
+ list(cond.if_branch.block_body.statements)
1730
+ ):
1731
+ if target not in seen:
1732
+ seen.add(target)
1733
+ assigned.append(target)
1734
+ for elif_branch in cond.elif_branches:
1735
+ if elif_branch.HasField("block_body"):
1736
+ for target in self._collect_assigned_vars_in_order(
1737
+ list(elif_branch.block_body.statements)
1738
+ ):
1739
+ if target not in seen:
1740
+ seen.add(target)
1741
+ assigned.append(target)
1742
+ if cond.HasField("else_branch") and cond.else_branch.HasField("block_body"):
1743
+ for target in self._collect_assigned_vars_in_order(
1744
+ list(cond.else_branch.block_body.statements)
1745
+ ):
1746
+ if target not in seen:
1747
+ seen.add(target)
1748
+ assigned.append(target)
1749
+
1750
+ if stmt.HasField("for_loop") and stmt.for_loop.HasField("block_body"):
1751
+ for target in self._collect_assigned_vars_in_order(
1752
+ list(stmt.for_loop.block_body.statements)
1753
+ ):
1754
+ if target not in seen:
1755
+ seen.add(target)
1756
+ assigned.append(target)
1757
+
1758
+ if stmt.HasField("try_except"):
1759
+ try_block = stmt.try_except.try_block
1760
+ if try_block.HasField("span"):
1761
+ for target in self._collect_assigned_vars_in_order(list(try_block.statements)):
1762
+ if target not in seen:
1763
+ seen.add(target)
1764
+ assigned.append(target)
1765
+ for handler in stmt.try_except.handlers:
1766
+ if handler.HasField("block_body"):
1767
+ for target in self._collect_assigned_vars_in_order(
1768
+ list(handler.block_body.statements)
1769
+ ):
1770
+ if target not in seen:
1771
+ seen.add(target)
1772
+ assigned.append(target)
1773
+
1774
+ return assigned
1775
+
1776
+ def _collect_variables_from_block(self, block: ir.Block) -> list[str]:
1777
+ return self._collect_variables_from_statements(list(block.statements))
1778
+
1779
+ def _collect_variables_from_statements(self, stmts: List[ir.Statement]) -> list[str]:
1780
+ """Collect variable references from statements in encounter order."""
1781
+ vars_found: list[str] = []
1782
+ seen: set[str] = set()
1783
+
1784
+ for stmt in stmts:
1785
+ if stmt.HasField("assignment") and stmt.assignment.HasField("value"):
1786
+ for var in self._collect_variables_from_expr(stmt.assignment.value):
1787
+ if var not in seen:
1788
+ seen.add(var)
1789
+ vars_found.append(var)
1790
+
1791
+ if stmt.HasField("return_stmt") and stmt.return_stmt.HasField("value"):
1792
+ for var in self._collect_variables_from_expr(stmt.return_stmt.value):
1793
+ if var not in seen:
1794
+ seen.add(var)
1795
+ vars_found.append(var)
1796
+
1797
+ if stmt.HasField("action_call"):
1798
+ expr = ir.Expr(action_call=stmt.action_call, span=stmt.span)
1799
+ for var in self._collect_variables_from_expr(expr):
1800
+ if var not in seen:
1801
+ seen.add(var)
1802
+ vars_found.append(var)
1803
+
1804
+ if stmt.HasField("expr_stmt"):
1805
+ for var in self._collect_variables_from_expr(stmt.expr_stmt.expr):
1806
+ if var not in seen:
1807
+ seen.add(var)
1808
+ vars_found.append(var)
1809
+
1810
+ if stmt.HasField("conditional"):
1811
+ cond = stmt.conditional
1812
+ if cond.HasField("if_branch"):
1813
+ if cond.if_branch.HasField("condition"):
1814
+ for var in self._collect_variables_from_expr(cond.if_branch.condition):
1815
+ if var not in seen:
1816
+ seen.add(var)
1817
+ vars_found.append(var)
1818
+ if cond.if_branch.HasField("block_body"):
1819
+ for var in self._collect_variables_from_block(cond.if_branch.block_body):
1820
+ if var not in seen:
1821
+ seen.add(var)
1822
+ vars_found.append(var)
1823
+ for elif_branch in cond.elif_branches:
1824
+ if elif_branch.HasField("condition"):
1825
+ for var in self._collect_variables_from_expr(elif_branch.condition):
1826
+ if var not in seen:
1827
+ seen.add(var)
1828
+ vars_found.append(var)
1829
+ if elif_branch.HasField("block_body"):
1830
+ for var in self._collect_variables_from_block(elif_branch.block_body):
1831
+ if var not in seen:
1832
+ seen.add(var)
1833
+ vars_found.append(var)
1834
+ if cond.HasField("else_branch") and cond.else_branch.HasField("block_body"):
1835
+ for var in self._collect_variables_from_block(cond.else_branch.block_body):
1836
+ if var not in seen:
1837
+ seen.add(var)
1838
+ vars_found.append(var)
1839
+
1840
+ if stmt.HasField("for_loop"):
1841
+ fl = stmt.for_loop
1842
+ if fl.HasField("iterable"):
1843
+ for var in self._collect_variables_from_expr(fl.iterable):
1844
+ if var not in seen:
1845
+ seen.add(var)
1846
+ vars_found.append(var)
1847
+ if fl.HasField("block_body"):
1848
+ for var in self._collect_variables_from_block(fl.block_body):
1849
+ if var not in seen:
1850
+ seen.add(var)
1851
+ vars_found.append(var)
1852
+
1853
+ if stmt.HasField("try_except"):
1854
+ te = stmt.try_except
1855
+ if te.HasField("try_block"):
1856
+ for var in self._collect_variables_from_block(te.try_block):
1857
+ if var not in seen:
1858
+ seen.add(var)
1859
+ vars_found.append(var)
1860
+ for handler in te.handlers:
1861
+ if handler.HasField("block_body"):
1862
+ for var in self._collect_variables_from_block(handler.block_body):
1863
+ if var not in seen:
1864
+ seen.add(var)
1865
+ vars_found.append(var)
1866
+
1867
+ if stmt.HasField("parallel_block"):
1868
+ for call in stmt.parallel_block.calls:
1869
+ if call.HasField("action"):
1870
+ for kwarg in call.action.kwargs:
1871
+ for var in self._collect_variables_from_expr(kwarg.value):
1872
+ if var not in seen:
1873
+ seen.add(var)
1874
+ vars_found.append(var)
1875
+ elif call.HasField("function"):
1876
+ for kwarg in call.function.kwargs:
1877
+ for var in self._collect_variables_from_expr(kwarg.value):
1878
+ if var not in seen:
1879
+ seen.add(var)
1880
+ vars_found.append(var)
1881
+
1882
+ if stmt.HasField("spread_action"):
1883
+ spread = stmt.spread_action
1884
+ if spread.HasField("collection"):
1885
+ for var in self._collect_variables_from_expr(spread.collection):
1886
+ if var not in seen:
1887
+ seen.add(var)
1888
+ vars_found.append(var)
1889
+ if spread.HasField("action"):
1890
+ for kwarg in spread.action.kwargs:
1891
+ for var in self._collect_variables_from_expr(kwarg.value):
1892
+ if var not in seen:
1893
+ seen.add(var)
1894
+ vars_found.append(var)
1895
+
1896
+ return vars_found
1897
+
1898
def _visit_try(self, node: ast.Try) -> List[ir.Statement]:
    """Lower a ``try``/``except`` statement into a single IR statement.

    The try body and every handler body are lowered statement by statement via
    ``_visit_statement``. For each handler, exception type names are taken from
    a bare name (``except Foo``) or from the name elements of a tuple
    (``except (Foo, Bar)``); other type expressions contribute no names. The
    handler's ``as`` name, when present, is stored as ``exception_var``.

    Only ``node.body`` and ``node.handlers`` are read here.

    Returns:
        A one-element list holding the ``try_except`` statement.
    """
    try_statements: List[ir.Statement] = [
        lowered for child in node.body for lowered in self._visit_statement(child)
    ]

    ir_handlers: List[ir.ExceptHandler] = []
    for clause in node.handlers:
        caught = clause.type
        if isinstance(caught, ast.Name):
            type_names: List[str] = [caught.id]
        elif isinstance(caught, ast.Tuple):
            type_names = [member.id for member in caught.elts if isinstance(member, ast.Name)]
        else:
            type_names = []

        clause_statements: List[ir.Statement] = [
            lowered for child in clause.body for lowered in self._visit_statement(child)
        ]

        ir_handler = ir.ExceptHandler(
            exception_types=type_names,
            block_body=ir.Block(statements=clause_statements, span=_make_span(clause)),
            span=_make_span(clause),
        )
        if clause.name:
            ir_handler.exception_var = clause.name
        ir_handlers.append(ir_handler)

    lowered_try = ir.Statement(span=_make_span(node))
    lowered_try.try_except.CopyFrom(
        ir.TryExcept(
            handlers=ir_handlers,
            try_block=ir.Block(statements=try_statements, span=_make_span(node)),
        )
    )
    return [lowered_try]
1942
+
1943
+ def _count_calls(self, stmts: List[ir.Statement]) -> int:
1944
+ """Count action calls and function calls in statements.
1945
+
1946
+ Both action calls and function calls (including synthetic functions)
1947
+ count toward the limit of one call per control flow body.
1948
+ """
1949
+ count = 0
1950
+ for stmt in stmts:
1951
+ if stmt.HasField("action_call"):
1952
+ count += 1
1953
+ elif stmt.HasField("assignment"):
1954
+ # Check if assignment value is an action call or function call
1955
+ if stmt.assignment.value.HasField("action_call"):
1956
+ count += 1
1957
+ elif stmt.assignment.value.HasField("function_call"):
1958
+ count += 1
1959
+ elif stmt.HasField("expr_stmt"):
1960
+ # Check if expression is a function call
1961
+ if stmt.expr_stmt.expr.HasField("function_call"):
1962
+ count += 1
1963
+ return count
1964
+
1965
def _wrap_body_as_function(
    self,
    body: List[ir.Statement],
    prefix: str,
    node: ast.AST,
    inputs: Optional[List[str]] = None,
    modified_vars: Optional[List[str]] = None,
) -> List[ir.Statement]:
    """Move ``body`` into a synthetic function and return the call to it.

    A new ``FunctionDef`` named via ``self._ctx.next_implicit_fn_name(prefix)``
    is appended to ``self._ctx.implicit_functions``.

    Args:
        body: Statements that become the synthetic function's body.
        prefix: Name prefix handed to the implicit-function name generator.
        node: AST node used for every span created here.
        inputs: Names passed into the function (e.g. loop variables).
        modified_vars: Names that are both passed in (if not already an
            input) and returned. With one name the function returns that
            variable; with several it returns a list expression of them.

    Returns:
        A one-element list: an assignment of the call result to
        ``modified_vars`` when there are any, otherwise an expression
        statement containing the call.
    """
    fn_name = self._ctx.next_implicit_fn_name(prefix)
    outputs = modified_vars or []

    # Inputs come first; modified variables are appended only when missing.
    parameters = list(inputs or [])
    for name in outputs:
        if name in parameters:
            continue
        parameters.append(name)

    function_body = list(body)
    if outputs:
        if len(outputs) == 1:
            returned = ir.Expr(
                variable=ir.Variable(name=outputs[0]),
                span=_make_span(node),
            )
        else:
            # Multiple results travel back as a list expression.
            elements = [ir.Expr(variable=ir.Variable(name=name)) for name in outputs]
            returned = ir.Expr(
                list=ir.ListExpr(elements=elements),
                span=_make_span(node),
            )
        trailing_return = ir.Statement(span=_make_span(node))
        trailing_return.return_stmt.CopyFrom(ir.ReturnStmt(value=returned))
        function_body.append(trailing_return)

    self._ctx.implicit_functions.append(
        ir.FunctionDef(
            name=fn_name,
            io=ir.IoDecl(inputs=parameters, outputs=outputs),
            body=ir.Block(statements=function_body),
            span=_make_span(node),
        )
    )

    # Every parameter is forwarded as a same-named keyword argument.
    invocation = ir.Expr(
        function_call=ir.FunctionCall(
            name=fn_name,
            kwargs=[
                ir.Kwarg(name=name, value=ir.Expr(variable=ir.Variable(name=name)))
                for name in parameters
            ],
        ),
        span=_make_span(node),
    )

    call_site = ir.Statement(span=_make_span(node))
    if outputs:
        call_site.assignment.CopyFrom(ir.Assignment(targets=outputs, value=invocation))
    else:
        call_site.expr_stmt.CopyFrom(ir.ExprStmt(expr=invocation))
    return [call_site]
2047
+
2048
def _visit_return(self, node: ast.Return) -> List[ir.Statement]:
    """Lower a ``return`` statement to IR.

    When a value is present it is first passed to
    ``_check_constructor_in_return``. Then:

    * If the value is an action call, it is split into
      ``_return_tmp = <action>`` followed by ``return _return_tmp``.
    * If the value converts to a function-call expression, it is split the
      same way using a temporary named by
      ``self._ctx.next_implicit_fn_name(prefix="return_tmp")``.
    * If the value converts to any other expression, a plain return of that
      expression is produced.

    A missing value, or a value that could not be converted to an expression,
    yields a return statement with no value.
    """

    def spanned_statement():
        return ir.Statement(span=_make_span(node))

    def assign_then_return(temp_name, value):
        # temp_name = value; return temp_name
        binding = spanned_statement()
        binding.assignment.CopyFrom(ir.Assignment(targets=[temp_name], value=value))
        returning = spanned_statement()
        temp_ref = ir.Expr(variable=ir.Variable(name=temp_name), span=_make_span(node))
        returning.return_stmt.CopyFrom(ir.ReturnStmt(value=temp_ref))
        return [binding, returning]

    returned = node.value
    if returned:
        self._check_constructor_in_return(returned)

        action_call = self._extract_action_call(returned)
        if action_call:
            return assign_then_return(
                "_return_tmp",
                ir.Expr(action_call=action_call, span=_make_span(node)),
            )

        expr = self._expr_to_ir_with_model_coercion(returned)
        if expr:
            if expr.HasField("function_call"):
                return assign_then_return(
                    self._ctx.next_implicit_fn_name(prefix="return_tmp"),
                    expr,
                )
            plain = spanned_statement()
            plain.return_stmt.CopyFrom(ir.ReturnStmt(value=expr))
            return [plain]

    bare = spanned_statement()
    bare.return_stmt.CopyFrom(ir.ReturnStmt())
    return [bare]
2109
+
2110
def _visit_break(self, node: ast.Break) -> List[ir.Statement]:
    """Lower ``break`` to a one-element list holding an IR break statement.

    The statement's span is derived from ``node``.
    """
    break_statement = ir.Statement(span=_make_span(node))
    empty_break = ir.BreakStmt()
    break_statement.break_stmt.CopyFrom(empty_break)
    return [break_statement]
2115
+
2116
def _visit_continue(self, node: ast.Continue) -> List[ir.Statement]:
    """Lower ``continue`` to a one-element list holding an IR continue statement.

    The statement's span is derived from ``node``.
    """
    continue_statement = ir.Statement(span=_make_span(node))
    empty_continue = ir.ContinueStmt()
    continue_statement.continue_stmt.CopyFrom(empty_continue)
    return [continue_statement]
2121
+
2122
def _visit_aug_assign(self, node: ast.AugAssign) -> List[ir.Statement]:
    """Lower an augmented assignment (``+=``, ``-=``, ...) to IR.

    ``target op= value`` is rewritten as ``target = target op value``. The
    assignment's target list holds the target's name when the target is a
    plain name and is empty otherwise.

    When the right-hand side converts to a function call, the call is first
    bound to a temporary (named via
    ``self._ctx.next_implicit_fn_name(prefix="aug_tmp")``) and the binary
    operation uses that temporary. If the target or operator cannot be
    converted in that case, only the temporary assignment is returned.

    Returns:
        The lowered statements, or an empty list when either operand or the
        operator could not be converted.
    """
    target_names: List[str] = [node.target.id] if isinstance(node.target, ast.Name) else []

    lhs = self._expr_to_ir_with_model_coercion(node.target)
    rhs = self._expr_to_ir_with_model_coercion(node.value)

    def rebind_with(right_operand):
        # Build `target = target <op> right_operand`, or None if the operator
        # has no IR equivalent.
        operator = _bin_op_to_ir(node.op)
        if not operator:
            return None
        rebinding = ir.Statement(span=_make_span(node))
        combined = ir.Expr(binary_op=ir.BinaryOp(left=lhs, op=operator, right=right_operand))
        rebinding.assignment.CopyFrom(ir.Assignment(targets=target_names, value=combined))
        return rebinding

    if rhs and rhs.HasField("function_call"):
        temp_name = self._ctx.next_implicit_fn_name(prefix="aug_tmp")

        hoisted = ir.Statement(span=_make_span(node))
        hoisted.assignment.CopyFrom(
            ir.Assignment(
                targets=[temp_name],
                value=ir.Expr(function_call=rhs.function_call, span=_make_span(node)),
            )
        )

        if lhs:
            rebinding = rebind_with(ir.Expr(variable=ir.Variable(name=temp_name)))
            if rebinding is not None:
                return [hoisted, rebinding]
        return [hoisted]

    if lhs and rhs:
        rebinding = rebind_with(rhs)
        if rebinding is not None:
            return [rebinding]

    return []
2169
+
2170
+ def _collect_function_inputs(self, node: ast.FunctionDef | ast.AsyncFunctionDef) -> List[str]:
2171
+ """Collect workflow inputs from function parameters, including kw-only args."""
2172
+ args: List[str] = []
2173
+ seen: set[str] = set()
2174
+
2175
+ ordered_args = list(node.args.posonlyargs) + list(node.args.args)
2176
+ if ordered_args and ordered_args[0].arg == "self":
2177
+ ordered_args = ordered_args[1:]
2178
+
2179
+ for arg in ordered_args:
2180
+ if arg.arg not in seen:
2181
+ args.append(arg.arg)
2182
+ seen.add(arg.arg)
2183
+
2184
+ if node.args.vararg and node.args.vararg.arg not in seen:
2185
+ args.append(node.args.vararg.arg)
2186
+ seen.add(node.args.vararg.arg)
2187
+
2188
+ for arg in node.args.kwonlyargs:
2189
+ if arg.arg not in seen:
2190
+ args.append(arg.arg)
2191
+ seen.add(arg.arg)
2192
+
2193
+ if node.args.kwarg and node.args.kwarg.arg not in seen:
2194
+ args.append(node.args.kwarg.arg)
2195
+ seen.add(node.args.kwarg.arg)
2196
+
2197
+ return args
2198
+
2199
+ def _check_constructor_in_return(self, node: ast.expr) -> None:
2200
+ """Check for constructor calls in return statements.
2201
+
2202
+ Raises UnsupportedPatternError if the return value is a class instantiation
2203
+ like: return MyModel(field=value)
2204
+
2205
+ This is not supported because the workflow IR cannot serialize arbitrary
2206
+ object instantiation. Users should use an @action to create objects.
2207
+ """
2208
+ # Skip if it's an await (action call) - those are fine
2209
+ if isinstance(node, ast.Await):
2210
+ return
2211
+
2212
+ # Check for direct Call that looks like a constructor
2213
+ if isinstance(node, ast.Call):
2214
+ func_name = self._get_constructor_name(node.func)
2215
+ if func_name and self._looks_like_constructor(func_name, node):
2216
+ line = getattr(node, "lineno", None)
2217
+ col = getattr(node, "col_offset", None)
2218
+ raise UnsupportedPatternError(
2219
+ f"Returning constructor call '{func_name}(...)' is not supported",
2220
+ RECOMMENDATIONS["constructor_return"],
2221
+ line=line,
2222
+ col=col,
2223
+ )
2224
+
2225
+ def _check_constructor_in_assignment(self, node: ast.expr) -> None:
2226
+ """Check for constructor calls in assignments.
2227
+
2228
+ Raises UnsupportedPatternError if the assignment value is a class instantiation
2229
+ like: result = MyModel(field=value)
2230
+
2231
+ This is not supported because the workflow IR cannot serialize arbitrary
2232
+ object instantiation. Users should use an @action to create objects.
2233
+ """
2234
+ # Skip if it's an await (action call) - those are fine
2235
+ if isinstance(node, ast.Await):
2236
+ return
2237
+
2238
+ # Check for direct Call that looks like a constructor
2239
+ if isinstance(node, ast.Call):
2240
+ func_name = self._get_constructor_name(node.func)
2241
+ if func_name and self._looks_like_constructor(func_name, node):
2242
+ line = getattr(node, "lineno", None)
2243
+ col = getattr(node, "col_offset", None)
2244
+ raise UnsupportedPatternError(
2245
+ f"Assigning constructor call '{func_name}(...)' is not supported",
2246
+ RECOMMENDATIONS["constructor_assignment"],
2247
+ line=line,
2248
+ col=col,
2249
+ )
2250
+
2251
+ def _get_constructor_name(self, func: ast.expr) -> Optional[str]:
2252
+ """Get the name from a function expression if it looks like a constructor."""
2253
+ if isinstance(func, ast.Name):
2254
+ return func.id
2255
+ elif isinstance(func, ast.Attribute):
2256
+ return func.attr
2257
+ return None
2258
+
2259
+ def _looks_like_constructor(self, func_name: str, call: ast.Call) -> bool:
2260
+ """Check if a function call looks like a class constructor.
2261
+
2262
+ A constructor is identified by:
2263
+ 1. Name starts with uppercase (PEP8 convention for classes)
2264
+ 2. It's not a known action
2265
+ 3. It's not a known builtin like String operations
2266
+ 4. It's not a known Pydantic model or dataclass (those are allowed)
2267
+
2268
+ This is a heuristic - we can't perfectly distinguish constructors
2269
+ from functions without full type information.
2270
+ """
2271
+ # Check if first letter is uppercase (class naming convention)
2272
+ if not func_name or not func_name[0].isupper():
2273
+ return False
2274
+
2275
+ # If it's a known action, it's not a constructor
2276
+ if func_name in self._action_defs:
2277
+ return False
2278
+
2279
+ # If it's a known Pydantic model or dataclass, allow it
2280
+ # (it will be converted to a dict expression)
2281
+ if func_name in self._model_defs:
2282
+ return False
2283
+
2284
+ # Common builtins that start with uppercase but aren't constructors
2285
+ # (these are rarely used in workflow code but let's be safe)
2286
+ builtin_exceptions = {"True", "False", "None", "Ellipsis"}
2287
+ if func_name in builtin_exceptions:
2288
+ return False
2289
+
2290
+ return True
2291
+
2292
+ def _is_model_constructor(self, node: ast.expr) -> Optional[str]:
2293
+ """Check if an expression is a Pydantic model or dataclass constructor call.
2294
+
2295
+ Returns the model name if it is, None otherwise.
2296
+ """
2297
+ if not isinstance(node, ast.Call):
2298
+ return None
2299
+
2300
+ func_name = self._get_constructor_name(node.func)
2301
+ if func_name and func_name in self._model_defs:
2302
+ return func_name
2303
+
2304
+ return None
2305
+
2306
def _convert_model_constructor_to_dict(self, node: ast.Call, model_name: str) -> ir.Expr:
    """Convert a model constructor call into an IR dict expression.

    For example ``MyModel(field1=value1, field2=value2)`` becomes
    ``{"field1": value1, "field2": value2}``.

    Entries are emitted in this order: keyword arguments as written,
    positional arguments (mapped onto the model's fields by position), then
    defaults for fields that were not supplied. A default that
    ``_constant_to_literal`` cannot turn into a literal is left out.

    Args:
        node: The constructor call.
        model_name: Key into ``self._model_defs`` for the called model.

    Raises:
        UnsupportedPatternError: For ``**kwargs`` expansion, more positional
            arguments than fields, or an argument value that cannot be
            converted to an IR expression.
    """
    model_def = self._model_defs[model_name]

    def unsupported(message: str, recommendation: str) -> UnsupportedPatternError:
        # All errors raised here point at the constructor call itself.
        return UnsupportedPatternError(
            message,
            recommendation,
            line=getattr(node, "lineno", None),
            col=getattr(node, "col_offset", None),
        )

    def string_key(text: str) -> ir.Expr:
        literal = ir.Literal()
        literal.string_value = text
        key = ir.Expr()
        key.literal.CopyFrom(literal)
        return key

    entries: List[ir.DictEntry] = []
    provided_fields: Set[str] = set()

    # Keyword arguments, in call order.
    for kw in node.keywords:
        if kw.arg is None:
            raise unsupported(
                f"Model constructor '{model_name}' with **kwargs is not supported",
                "Use explicit keyword arguments instead of **kwargs.",
            )

        provided_fields.add(kw.arg)
        value_expr = self._expr_to_ir_with_model_coercion(kw.value)
        if value_expr is None:
            raise unsupported(
                f"Cannot convert value for field '{kw.arg}' in '{model_name}'",
                "Use simpler expressions (literals, variables, dicts, lists).",
            )
        entries.append(ir.DictEntry(key=string_key(kw.arg), value=value_expr))

    # Positional arguments map onto fields in declaration order.
    if node.args:
        field_names = list(model_def.fields.keys())
        field_count = len(field_names)
        for position, arg in enumerate(node.args):
            if position >= field_count:
                raise unsupported(
                    f"Too many positional arguments for '{model_name}'",
                    "Use keyword arguments for clarity.",
                )

            field_name = field_names[position]
            provided_fields.add(field_name)

            value_expr = self._expr_to_ir_with_model_coercion(arg)
            if value_expr is None:
                raise unsupported(
                    f"Cannot convert positional argument for field '{field_name}' in '{model_name}'",
                    "Use simpler expressions (literals, variables, dicts, lists).",
                )
            entries.append(ir.DictEntry(key=string_key(field_name), value=value_expr))

    # Fill in defaults for everything the call did not supply.
    for field_name, field_def in model_def.fields.items():
        if field_name in provided_fields or not field_def.has_default:
            continue

        default_literal = _constant_to_literal(field_def.default_value)
        if default_literal is None:
            # Not expressible as an IR literal - leave the field out.
            continue

        default_expr = ir.Expr()
        default_expr.literal.CopyFrom(default_literal)
        entries.append(ir.DictEntry(key=string_key(field_name), value=default_expr))

    result = ir.Expr(span=_make_span(node))
    result.dict.CopyFrom(ir.DictExpr(entries=entries))
    return result
2418
+
2419
+ def _check_non_action_await(self, node: ast.Await) -> None:
2420
+ """Check if an await is for a non-action function.
2421
+
2422
+ Note: We can only reliably detect non-action awaits for functions defined
2423
+ in the same module. Actions imported from other modules will pass through
2424
+ and may fail at runtime if they're not actually actions.
2425
+
2426
+ For now, we only check against common builtins and known non-action patterns.
2427
+ A runtime check will catch functions that aren't registered actions.
2428
+ """
2429
+ awaited = node.value
2430
+ if not isinstance(awaited, ast.Call):
2431
+ return
2432
+
2433
+ # Skip special cases that are handled elsewhere
2434
+ if self._is_run_action_call(awaited):
2435
+ return
2436
+ if self._is_asyncio_sleep_call(awaited):
2437
+ return
2438
+ if self._is_asyncio_gather_call(awaited):
2439
+ return
2440
+
2441
+ # Get the function name
2442
+ func_name = None
2443
+ if isinstance(awaited.func, ast.Name):
2444
+ func_name = awaited.func.id
2445
+ elif isinstance(awaited.func, ast.Attribute):
2446
+ func_name = awaited.func.attr
2447
+
2448
+ if not func_name:
2449
+ return
2450
+
2451
+ # Only raise error for functions defined in THIS module that we know
2452
+ # are NOT actions (i.e., async functions without @action decorator)
2453
+ # We can't reliably detect imported non-actions without full type info.
2454
+ #
2455
+ # The check works by looking at _module_functions which contains
2456
+ # functions defined in the same module as the workflow.
2457
+ if func_name in getattr(self, "_module_functions", set()):
2458
+ if func_name not in self._action_defs:
2459
+ line = getattr(node, "lineno", None)
2460
+ col = getattr(node, "col_offset", None)
2461
+ raise UnsupportedPatternError(
2462
+ f"Awaiting non-action function '{func_name}()' is not supported",
2463
+ RECOMMENDATIONS["non_action_call"],
2464
+ line=line,
2465
+ col=col,
2466
+ )
2467
+
2468
+ def _check_sync_function_call(self, node: ast.Call) -> None:
2469
+ """Check for synchronous function calls that should be in actions.
2470
+
2471
+ Common patterns like len(), str(), etc. are not supported in workflow code.
2472
+ """
2473
+ func_name = None
2474
+ if isinstance(node.func, ast.Name):
2475
+ func_name = node.func.id
2476
+ elif isinstance(node.func, ast.Attribute):
2477
+ # Method calls on objects - check the method name
2478
+ func_name = node.func.attr
2479
+
2480
+ if not func_name:
2481
+ return
2482
+
2483
+ # Builtins that users commonly try to use
2484
+ common_builtins = {
2485
+ "len",
2486
+ "str",
2487
+ "int",
2488
+ "float",
2489
+ "bool",
2490
+ "list",
2491
+ "dict",
2492
+ "set",
2493
+ "tuple",
2494
+ "sum",
2495
+ "min",
2496
+ "max",
2497
+ "sorted",
2498
+ "reversed",
2499
+ "enumerate",
2500
+ "zip",
2501
+ "map",
2502
+ "filter",
2503
+ "range",
2504
+ "abs",
2505
+ "round",
2506
+ "print",
2507
+ "type",
2508
+ "isinstance",
2509
+ "hasattr",
2510
+ "getattr",
2511
+ "setattr",
2512
+ "open",
2513
+ "format",
2514
+ }
2515
+
2516
+ if func_name in common_builtins:
2517
+ line = getattr(node, "lineno", None)
2518
+ col = getattr(node, "col_offset", None)
2519
+ raise UnsupportedPatternError(
2520
+ f"Calling built-in function '{func_name}()' directly is not supported",
2521
+ RECOMMENDATIONS["builtin_call"],
2522
+ line=line,
2523
+ col=col,
2524
+ )
2525
+
2526
+ def _extract_action_call(self, node: ast.expr) -> Optional[ir.ActionCall]:
2527
+ """Extract an action call from an expression if present.
2528
+
2529
+ Also validates that awaited calls are actually @action decorated functions.
2530
+ Raises UnsupportedPatternError if awaiting a non-action function.
2531
+ """
2532
+ if not isinstance(node, ast.Await):
2533
+ return None
2534
+
2535
+ awaited = node.value
2536
+ # Handle self.run_action(...) wrapper
2537
+ if isinstance(awaited, ast.Call):
2538
+ if self._is_run_action_call(awaited):
2539
+ # Extract the actual action call from run_action
2540
+ if awaited.args:
2541
+ action_call = self._extract_action_call_from_awaitable(awaited.args[0])
2542
+ if action_call:
2543
+ # Extract policies from run_action kwargs (retry, timeout)
2544
+ self._extract_policies_from_run_action(awaited, action_call)
2545
+ return action_call
2546
+ # Check for asyncio.sleep() - convert to @sleep action
2547
+ if self._is_asyncio_sleep_call(awaited):
2548
+ return self._convert_asyncio_sleep_to_action(awaited)
2549
+ # Try to extract as action call
2550
+ action_call = self._extract_action_call_from_call(awaited)
2551
+ if action_call:
2552
+ return action_call
2553
+
2554
+ # If we get here, it's an await of a non-action function
2555
+ self._check_non_action_await(node)
2556
+ return None
2557
+
2558
+ return None
2559
+
2560
+ def _is_run_action_call(self, node: ast.Call) -> bool:
2561
+ """Check if this is a self.run_action(...) call."""
2562
+ if isinstance(node.func, ast.Attribute):
2563
+ return node.func.attr == "run_action"
2564
+ return False
2565
+
2566
+ def _extract_policies_from_run_action(
2567
+ self, run_action_call: ast.Call, action_call: ir.ActionCall
2568
+ ) -> None:
2569
+ """Extract retry and timeout policies from run_action kwargs.
2570
+
2571
+ Parses patterns like:
2572
+ - self.run_action(action(), retry=RetryPolicy(attempts=3))
2573
+ - self.run_action(action(), timeout=timedelta(seconds=30))
2574
+ - self.run_action(action(), timeout=60)
2575
+ """
2576
+ for kw in run_action_call.keywords:
2577
+ if kw.arg == "retry":
2578
+ retry_policy = self._parse_retry_policy(kw.value)
2579
+ if retry_policy:
2580
+ policy_bracket = ir.PolicyBracket()
2581
+ policy_bracket.retry.CopyFrom(retry_policy)
2582
+ action_call.policies.append(policy_bracket)
2583
+ elif kw.arg == "timeout":
2584
+ timeout_policy = self._parse_timeout_policy(kw.value)
2585
+ if timeout_policy:
2586
+ policy_bracket = ir.PolicyBracket()
2587
+ policy_bracket.timeout.CopyFrom(timeout_policy)
2588
+ action_call.policies.append(policy_bracket)
2589
+
2590
+ def _parse_retry_policy(self, node: ast.expr) -> Optional[ir.RetryPolicy]:
2591
+ """Parse a RetryPolicy(...) call into IR.
2592
+
2593
+ Supports:
2594
+ - RetryPolicy(attempts=3)
2595
+ - RetryPolicy(attempts=3, exception_types=["ValueError"])
2596
+ - RetryPolicy(attempts=3, backoff_seconds=5)
2597
+ - self.retry_policy (instance attribute reference)
2598
+ """
2599
+ # Handle self.attr pattern - look up in instance attrs
2600
+ if (
2601
+ isinstance(node, ast.Attribute)
2602
+ and isinstance(node.value, ast.Name)
2603
+ and node.value.id == "self"
2604
+ ):
2605
+ attr_name = node.attr
2606
+ if attr_name in self._instance_attrs:
2607
+ return self._parse_retry_policy(self._instance_attrs[attr_name])
2608
+ return None
2609
+
2610
+ if not isinstance(node, ast.Call):
2611
+ return None
2612
+
2613
+ # Check if it's a RetryPolicy call
2614
+ func_name = None
2615
+ if isinstance(node.func, ast.Name):
2616
+ func_name = node.func.id
2617
+ elif isinstance(node.func, ast.Attribute):
2618
+ func_name = node.func.attr
2619
+
2620
+ if func_name != "RetryPolicy":
2621
+ return None
2622
+
2623
+ policy = ir.RetryPolicy()
2624
+
2625
+ for kw in node.keywords:
2626
+ if kw.arg == "attempts" and isinstance(kw.value, ast.Constant):
2627
+ # attempts means total executions, max_retries means retries after first attempt
2628
+ # So attempts=1 -> max_retries=0 (no retries), attempts=3 -> max_retries=2
2629
+ policy.max_retries = kw.value.value - 1
2630
+ elif kw.arg == "exception_types" and isinstance(kw.value, ast.List):
2631
+ for elt in kw.value.elts:
2632
+ if isinstance(elt, ast.Constant) and isinstance(elt.value, str):
2633
+ policy.exception_types.append(elt.value)
2634
+ elif kw.arg == "backoff_seconds" and isinstance(kw.value, ast.Constant):
2635
+ policy.backoff.seconds = int(kw.value.value)
2636
+
2637
+ return policy
2638
+
2639
def _parse_timeout_policy(self, node: ast.expr) -> Optional[ir.TimeoutPolicy]:
    """Parse a ``timeout=`` value into a TimeoutPolicy.

    Supports:
    - timeout=60 (int seconds)
    - timeout=30.5 (float seconds, truncated to whole seconds)
    - timeout=timedelta(seconds=30)
    - timeout=timedelta(minutes=1.5), timedelta(weeks=1), timedelta(0, 30)
    - self.timeout (instance attribute reference)

    For ``timedelta(...)`` every numeric literal component is converted to
    seconds and summed *before* truncating to whole seconds, so fractional
    units are honoured (``minutes=1.5`` is 90 seconds, not 60). Positional
    arguments follow the ``datetime.timedelta`` parameter order. Components
    that are not numeric literals are ignored.

    Returns:
        The TimeoutPolicy, or None when ``node`` is not a recognised form.
    """
    # Handle self.attr pattern - look up in instance attrs
    if (
        isinstance(node, ast.Attribute)
        and isinstance(node.value, ast.Name)
        and node.value.id == "self"
    ):
        attr_name = node.attr
        if attr_name in self._instance_attrs:
            return self._parse_timeout_policy(self._instance_attrs[attr_name])
        return None

    # Direct numeric value (seconds)
    if isinstance(node, ast.Constant) and isinstance(node.value, (int, float)):
        policy = ir.TimeoutPolicy()
        policy.timeout.seconds = int(node.value)
        return policy

    if not isinstance(node, ast.Call):
        return None

    # Accept both timedelta(...) and datetime.timedelta(...).
    func_name = None
    if isinstance(node.func, ast.Name):
        func_name = node.func.id
    elif isinstance(node.func, ast.Attribute):
        func_name = node.func.attr
    if func_name != "timedelta":
        return None

    seconds_per_unit = {
        "weeks": 604800,
        "days": 86400,
        "hours": 3600,
        "minutes": 60,
        "seconds": 1,
        "milliseconds": 1e-3,
        "microseconds": 1e-6,
    }
    # Positional parameter order of datetime.timedelta.
    positional_units = (
        "days",
        "seconds",
        "microseconds",
        "milliseconds",
        "minutes",
        "hours",
        "weeks",
    )

    components = []
    for unit, arg in zip(positional_units, node.args):
        if isinstance(arg, ast.Starred):
            # Positions after a *spread cannot be mapped to units statically.
            break
        components.append((unit, arg))
    components.extend((kw.arg, kw.value) for kw in node.keywords)

    total_seconds = 0.0
    for unit, value_node in components:
        factor = seconds_per_unit.get(unit)
        if factor is None or not isinstance(value_node, ast.Constant):
            continue
        value = value_node.value
        if not isinstance(value, (int, float)):
            continue
        total_seconds += value * factor

    policy = ir.TimeoutPolicy()
    # Round to microseconds first so float noise (e.g. 56.99999999999999)
    # cannot push the truncated result one second low.
    policy.timeout.seconds = int(round(total_seconds, 6))
    return policy
2692
+
2693
+ def _is_asyncio_sleep_call(self, node: ast.Call) -> bool:
2694
+ """Check if this is an asyncio.sleep(...) call.
2695
+
2696
+ Supports both patterns:
2697
+ - import asyncio; asyncio.sleep(1)
2698
+ - from asyncio import sleep; sleep(1)
2699
+ - from asyncio import sleep as s; s(1)
2700
+ """
2701
+ if isinstance(node.func, ast.Attribute):
2702
+ # asyncio.sleep(...) pattern
2703
+ if node.func.attr == "sleep" and isinstance(node.func.value, ast.Name):
2704
+ return node.func.value.id == "asyncio"
2705
+ elif isinstance(node.func, ast.Name):
2706
+ # sleep(...) pattern - check if it's imported from asyncio
2707
+ func_name = node.func.id
2708
+ if func_name in self._imported_names:
2709
+ imported = self._imported_names[func_name]
2710
+ return imported.module == "asyncio" and imported.original_name == "sleep"
2711
+ return False
2712
+
2713
def _convert_asyncio_sleep_to_action(self, node: ast.Call) -> ir.ActionCall:
    """Rewrite ``asyncio.sleep(x)`` as an ActionCall named ``sleep``.

    The duration is taken from the first positional argument when there is
    one, otherwise from the first keyword named ``seconds``, ``delay`` or
    ``duration``. It is emitted as a single ``duration`` kwarg; if no
    duration is found, or it cannot be converted to IR, the action call is
    returned without kwargs.
    """
    sleep_call = ir.ActionCall(action_name="sleep")

    duration_node = None
    if node.args:
        # asyncio.sleep(1)
        duration_node = node.args[0]
    else:
        # asyncio.sleep(delay=1) - first recognised keyword wins
        for keyword in node.keywords:
            if keyword.arg in ("seconds", "delay", "duration"):
                duration_node = keyword.value
                break

    if duration_node is not None:
        duration = self._expr_to_ir_with_model_coercion(duration_node)
        if duration:
            sleep_call.kwargs.append(ir.Kwarg(name="duration", value=duration))

    return sleep_call
2737
+
2738
+ def _is_asyncio_gather_call(self, node: ast.Call) -> bool:
2739
+ """Check if this is an asyncio.gather(...) call.
2740
+
2741
+ Supports both patterns:
2742
+ - import asyncio; asyncio.gather(a(), b())
2743
+ - from asyncio import gather; gather(a(), b())
2744
+ - from asyncio import gather as g; g(a(), b())
2745
+ """
2746
+ if isinstance(node.func, ast.Attribute):
2747
+ # asyncio.gather(...) pattern
2748
+ if node.func.attr == "gather" and isinstance(node.func.value, ast.Name):
2749
+ return node.func.value.id == "asyncio"
2750
+ elif isinstance(node.func, ast.Name):
2751
+ # gather(...) pattern - check if it's imported from asyncio
2752
+ func_name = node.func.id
2753
+ if func_name in self._imported_names:
2754
+ imported = self._imported_names[func_name]
2755
+ return imported.module == "asyncio" and imported.original_name == "gather"
2756
+ return False
2757
+
2758
def _convert_asyncio_gather(
    self, node: ast.Call
) -> Optional[Union[ir.ParallelExpr, ir.SpreadExpr]]:
    """Convert asyncio.gather(...) to ParallelExpr or SpreadExpr IR.

    Handles two patterns:
    1. Static gather: asyncio.gather(a(), b(), c()) -> ParallelExpr
    2. Spread gather: asyncio.gather(*[action(x) for x in items]) -> SpreadExpr

    Args:
        node: The asyncio.gather() Call node

    Returns:
        A ParallelExpr, SpreadExpr, or None if no argument could be
        converted to a call.

    Raises:
        UnsupportedPatternError: If a starred argument spreads anything
            other than a list comprehension, or if a starred argument is
            mixed with other positional arguments. The mixed form used to
            be accepted with the starred argument silently dropped, which
            changed the number of gathered results.
    """
    starred_args = [arg for arg in node.args if isinstance(arg, ast.Starred)]
    if starred_args:
        line = getattr(node, "lineno", None)
        col = getattr(node, "col_offset", None)

        if len(node.args) > 1:
            raise UnsupportedPatternError(
                "Mixing spread (*) arguments with other arguments in asyncio.gather() "
                "is not supported",
                RECOMMENDATIONS["gather_variable_spread"],
                line=line,
                col=col,
            )

        spread_source = starred_args[0].value
        # Only list comprehensions are supported for spread
        if isinstance(spread_source, ast.ListComp):
            return self._convert_listcomp_to_spread_expr(spread_source)
        if isinstance(spread_source, ast.Name):
            raise UnsupportedPatternError(
                f"Spreading variable '{spread_source.id}' in asyncio.gather() is not supported",
                RECOMMENDATIONS["gather_variable_spread"],
                line=line,
                col=col,
            )
        raise UnsupportedPatternError(
            "Spreading non-list-comprehension expressions in asyncio.gather() is not supported",
            RECOMMENDATIONS["gather_variable_spread"],
            line=line,
            col=col,
        )

    # Standard case: gather(a(), b(), c()) -> ParallelExpr
    parallel = ir.ParallelExpr()
    for arg in node.args:
        call = self._convert_gather_arg_to_call(arg)
        if call:
            parallel.calls.append(call)

    # Only return if we have calls
    if not parallel.calls:
        return None
    return parallel
2813
+
2814
def _convert_listcomp_to_spread_expr(self, listcomp: ast.ListComp) -> Optional[ir.SpreadExpr]:
    """Build a SpreadExpr from ``[action(x=item) for item in collection]``.

    The comprehension must have exactly one generator, no ``if`` clause, a
    plain-name loop variable, a collection convertible to IR, and an
    element that is a call to a registered action.

    Args:
        listcomp: The ListComp AST node

    Returns:
        The populated SpreadExpr.

    Raises:
        UnsupportedPatternError: If any of the requirements above is not met.
    """
    line = getattr(listcomp, "lineno", None)
    col = getattr(listcomp, "col_offset", None)

    def reject(message: str, recommendation: str) -> NoReturn:
        # Every failure reports the comprehension's own source position.
        raise UnsupportedPatternError(message, recommendation, line=line, col=col)

    if len(listcomp.generators) != 1:
        reject(
            "Spread pattern only supports a single loop variable",
            "Use a simple list comprehension: [action(x) for x in items]",
        )
    generator = listcomp.generators[0]

    if generator.ifs:
        reject(
            "Spread pattern does not support conditions in list comprehension",
            "Remove the 'if' clause from the comprehension",
        )

    target = generator.target
    if not isinstance(target, ast.Name):
        reject(
            "Spread pattern requires a simple loop variable",
            "Use a simple variable: [action(x) for x in items]",
        )

    collection = self._expr_to_ir_with_model_coercion(generator.iter)
    if not collection:
        reject(
            "Could not convert collection expression in spread pattern",
            "Ensure the collection is a simple variable or expression",
        )

    element = listcomp.elt
    if not isinstance(element, ast.Call):
        reject(
            "Spread pattern requires an action call in the list comprehension",
            "Use: [action(x=item) for item in items]",
        )

    action = self._extract_action_call_from_call(element)
    if not action:
        reject(
            "Spread pattern element must be an @action call",
            "Ensure the function is decorated with @action",
        )

    spread = ir.SpreadExpr()
    spread.collection.CopyFrom(collection)
    spread.loop_var = target.id
    spread.action.CopyFrom(action)
    return spread
2906
+
2907
+ def _convert_gather_arg_to_call(self, node: ast.expr) -> Optional[ir.Call]:
2908
+ """Convert a gather argument to an IR Call.
2909
+
2910
+ Handles both action calls and regular function calls.
2911
+ """
2912
+ if not isinstance(node, ast.Call):
2913
+ return None
2914
+
2915
+ # Try to extract as an action call first
2916
+ action_call = self._extract_action_call_from_call(node)
2917
+ if action_call:
2918
+ call = ir.Call()
2919
+ call.action.CopyFrom(action_call)
2920
+ return call
2921
+
2922
+ # Fall back to regular function call
2923
+ func_call = self._convert_to_function_call(node)
2924
+ if func_call:
2925
+ call = ir.Call()
2926
+ call.function.CopyFrom(func_call)
2927
+ return call
2928
+
2929
+ return None
2930
+
2931
def _convert_to_function_call(self, node: ast.Call) -> Optional[ir.FunctionCall]:
    """Convert an AST Call into an IR FunctionCall.

    Returns None when the callee has no resolvable name. Positional and
    named keyword arguments that convert to IR are copied over;
    ``**kwargs`` spreads and arguments that fail to convert are skipped.
    Finally, missing kwargs are filled in from the known function
    signature, if any.
    """
    name = self._get_func_name(node.func)
    if not name:
        return None

    function_call = ir.FunctionCall(name=name)
    builtin = _global_function_for_call(name, node)
    if builtin is not None:
        function_call.global_function = builtin

    for positional in node.args:
        converted = self._expr_to_ir_with_model_coercion(positional)
        if not converted:
            continue
        function_call.args.append(converted)

    for keyword in node.keywords:
        if not keyword.arg:
            # **kwargs spread has no name to attach the value to
            continue
        converted = self._expr_to_ir_with_model_coercion(keyword.value)
        if not converted:
            continue
        function_call.kwargs.append(ir.Kwarg(name=keyword.arg, value=converted))

    self._fill_default_kwargs_for_expr(function_call)
    return function_call
2959
+
2960
+ def _get_func_name(self, node: ast.expr) -> Optional[str]:
2961
+ """Get function name from a func node."""
2962
+ if isinstance(node, ast.Name):
2963
+ return node.id
2964
+ elif isinstance(node, ast.Attribute):
2965
+ # Handle chained attributes like obj.method
2966
+ parts = []
2967
+ current = node
2968
+ while isinstance(current, ast.Attribute):
2969
+ parts.append(current.attr)
2970
+ current = current.value
2971
+ if isinstance(current, ast.Name):
2972
+ parts.append(current.id)
2973
+ name = ".".join(reversed(parts))
2974
+ if name.startswith("self."):
2975
+ return name[5:]
2976
+ return name
2977
+ return None
2978
+
2979
+ def _convert_model_constructor_if_needed(self, node: ast.Call) -> Optional[ir.Expr]:
2980
+ model_name = self._is_model_constructor(node)
2981
+ if model_name:
2982
+ return self._convert_model_constructor_to_dict(node, model_name)
2983
+ return None
2984
+
2985
def _resolve_enum_attribute(self, node: ast.Attribute) -> Optional[ir.Expr]:
    """Resolve an ``Enum.MEMBER`` reference to a literal IR expression.

    The attribute chain is looked up in ``self._module_globals``.

    Returns:
        A literal Expr carrying the member's value and ``node``'s span, or
        None when the attribute does not resolve to an enum member value.

    Raises:
        UnsupportedPatternError: If the member's value cannot be expressed
            as an IR literal.
    """
    member_value = _resolve_enum_attribute_value(node, self._module_globals)
    if member_value is None:
        return None

    literal = _constant_to_literal(member_value)
    if literal is not None:
        resolved = ir.Expr(span=_make_span(node))
        resolved.literal.CopyFrom(literal)
        return resolved

    raise UnsupportedPatternError(
        "Enum value must be a primitive literal",
        RECOMMENDATIONS["unsupported_literal"],
        line=getattr(node, "lineno", None),
        col=getattr(node, "col_offset", None),
    )
3002
+
3003
def _expr_to_ir_with_model_coercion(self, node: ast.expr) -> Optional[ir.Expr]:
    """Convert an AST expression to IR using this builder's hooks.

    Model constructors are turned into dict expressions and enum attribute
    references into literals. On success, default kwargs are filled in for
    every function call nested in the result. Returns None when the
    expression cannot be converted.
    """
    converted = _expr_to_ir(
        node,
        model_converter=self._convert_model_constructor_if_needed,
        enum_resolver=self._resolve_enum_attribute,
    )
    if converted is None:
        return None
    # Post-process so nested function calls pick up signature defaults.
    self._fill_default_kwargs_recursive(converted)
    return converted
3014
+
3015
def _fill_default_kwargs_recursive(self, expr: ir.Expr) -> None:
    """Fill default kwargs for every function call nested inside ``expr``.

    A function call gets its own defaults filled first; then the walk
    descends into its args and kwargs. Binary/unary operands, list
    elements, dict keys and values, index object/index and dot objects are
    walked as well. Other expression kinds are left untouched.
    """
    if expr.HasField("function_call"):
        call = expr.function_call
        self._fill_default_kwargs_for_expr(call)
        pending = list(call.args)
        pending += [kwarg.value for kwarg in call.kwargs if kwarg.value]
    elif expr.HasField("binary_op"):
        operation = expr.binary_op
        pending = [side for side in (operation.left, operation.right) if side]
    elif expr.HasField("unary_op"):
        pending = [expr.unary_op.operand] if expr.unary_op.operand else []
    elif expr.HasField("list"):
        pending = list(expr.list.elements)
    elif expr.HasField("dict"):
        pending = []
        for entry in expr.dict.entries:
            pending += [part for part in (entry.key, entry.value) if part]
    elif expr.HasField("index"):
        access = expr.index
        pending = [part for part in (access.object, access.index) if part]
    elif expr.HasField("dot"):
        pending = [expr.dot.object] if expr.dot.object else []
    else:
        pending = []

    for child in pending:
        self._fill_default_kwargs_recursive(child)
3050
+
3051
+ def _fill_default_kwargs_for_expr(self, fn_call: ir.FunctionCall) -> None:
3052
+ """Fill in missing kwargs with default values from the function signature."""
3053
+ sig_info = self._ctx.function_signatures.get(fn_call.name)
3054
+ if sig_info is None:
3055
+ return
3056
+
3057
+ # Track which parameters are already provided
3058
+ provided_by_position: Set[str] = set()
3059
+ provided_by_kwarg: Set[str] = set()
3060
+
3061
+ # Positional args map to parameters in order
3062
+ for idx, _arg in enumerate(fn_call.args):
3063
+ if idx < len(sig_info.parameters):
3064
+ provided_by_position.add(sig_info.parameters[idx].name)
3065
+
3066
+ # Kwargs are named
3067
+ for kwarg in fn_call.kwargs:
3068
+ provided_by_kwarg.add(kwarg.name)
3069
+
3070
+ # Add defaults for missing parameters
3071
+ for param in sig_info.parameters:
3072
+ if param.name in provided_by_position or param.name in provided_by_kwarg:
3073
+ continue
3074
+ if not param.has_default:
3075
+ continue
3076
+
3077
+ # Convert the default value to an IR expression
3078
+ literal = _constant_to_literal(param.default_value)
3079
+ if literal is not None:
3080
+ expr = ir.Expr()
3081
+ expr.literal.CopyFrom(literal)
3082
+ fn_call.kwargs.append(ir.Kwarg(name=param.name, value=expr))
3083
+
3084
+ def _extract_action_call_from_awaitable(self, node: ast.expr) -> Optional[ir.ActionCall]:
3085
+ """Extract action call from an awaitable expression."""
3086
+ if isinstance(node, ast.Call):
3087
+ return self._extract_action_call_from_call(node)
3088
+ return None
3089
+
3090
+ def _extract_action_call_from_call(self, node: ast.Call) -> Optional[ir.ActionCall]:
3091
+ """Extract action call info from a Call node.
3092
+
3093
+ Converts positional arguments to keyword arguments using the action's
3094
+ signature introspection. This ensures all arguments are named in the IR.
3095
+
3096
+ Pydantic models and dataclass constructors passed as arguments are
3097
+ automatically converted to dict expressions.
3098
+ """
3099
+ action_name = self._get_action_name(node.func)
3100
+ if not action_name:
3101
+ return None
3102
+
3103
+ if action_name not in self._action_defs:
3104
+ return None
3105
+
3106
+ action_def = self._action_defs[action_name]
3107
+ action_call = ir.ActionCall(action_name=action_def.action_name)
3108
+
3109
+ # Set the module name so the worker knows where to find the action
3110
+ if action_def.module_name:
3111
+ action_call.module_name = action_def.module_name
3112
+
3113
+ # Get parameter names from signature for positional arg conversion
3114
+ param_names = list(action_def.signature.parameters.keys())
3115
+
3116
+ # Convert positional args to kwargs using signature introspection
3117
+ # Model constructors are converted to dict expressions
3118
+ for i, arg in enumerate(node.args):
3119
+ if i < len(param_names):
3120
+ expr = self._expr_to_ir_with_model_coercion(arg)
3121
+ if expr:
3122
+ kwarg = ir.Kwarg(name=param_names[i], value=expr)
3123
+ action_call.kwargs.append(kwarg)
3124
+
3125
+ # Add explicit kwargs
3126
+ # Model constructors are converted to dict expressions
3127
+ for kw in node.keywords:
3128
+ if kw.arg:
3129
+ expr = self._expr_to_ir_with_model_coercion(kw.value)
3130
+ if expr:
3131
+ kwarg = ir.Kwarg(name=kw.arg, value=expr)
3132
+ action_call.kwargs.append(kwarg)
3133
+
3134
+ return action_call
3135
+
3136
+ def _get_action_name(self, func: ast.expr) -> Optional[str]:
3137
+ """Get the action name from a function expression."""
3138
+ if isinstance(func, ast.Name):
3139
+ return func.id
3140
+ elif isinstance(func, ast.Attribute):
3141
+ return func.attr
3142
+ return None
3143
+
3144
+ def _get_assign_target(self, targets: List[ast.expr]) -> Optional[str]:
3145
+ """Get the target variable name from assignment targets (single target only)."""
3146
+ if targets and isinstance(targets[0], ast.Name):
3147
+ return targets[0].id
3148
+ return None
3149
+
3150
+ def _get_assign_targets(self, targets: List[ast.expr]) -> List[str]:
3151
+ """Get all target variable names from assignment targets (including tuple unpacking)."""
3152
+ result: List[str] = []
3153
+ for t in targets:
3154
+ if isinstance(t, ast.Name):
3155
+ result.append(t.id)
3156
+ elif isinstance(t, ast.Subscript):
3157
+ formatted = _format_subscript_target(t)
3158
+ if formatted:
3159
+ result.append(formatted)
3160
+ elif isinstance(t, ast.Tuple):
3161
+ for elt in t.elts:
3162
+ if isinstance(elt, ast.Name):
3163
+ result.append(elt.id)
3164
+ return result
3165
+
3166
+
3167
def _make_span(node: ast.AST) -> ir.Span:
    """Build a Span from the source position recorded on an AST node.

    Position attributes that are absent default to 0; an end position that
    is present but None is reported as 0 as well.
    """

    def position(field: str) -> Any:
        return getattr(node, field, 0)

    return ir.Span(
        start_line=position("lineno"),
        start_col=position("col_offset"),
        end_line=position("end_lineno") or 0,
        end_col=position("end_col_offset") or 0,
    )
3175
+
3176
+
3177
+ def _attribute_chain(node: ast.Attribute) -> Optional[List[str]]:
3178
+ parts: List[str] = []
3179
+ current: ast.AST = node
3180
+ while isinstance(current, ast.Attribute):
3181
+ parts.append(current.attr)
3182
+ current = current.value
3183
+ if isinstance(current, ast.Name):
3184
+ parts.append(current.id)
3185
+ return list(reversed(parts))
3186
+ return None
3187
+
3188
+
3189
def _resolve_enum_attribute_value(
    node: ast.Attribute,
    module_globals: Mapping[str, Any],
) -> Optional[Any]:
    """Look up the value of an enum member referenced by an attribute chain.

    The first segment is resolved in ``module_globals``, intermediate
    segments through each object's ``__dict__``, and the last segment as a
    member name of the resulting enum class.

    Returns:
        The member's ``value``, or None when the chain has fewer than two
        segments, any lookup fails, or the owner is not an enum class.
    """
    chain = _attribute_chain(node)
    if not chain or len(chain) < 2:
        return None

    root_name, *path, member_name = chain

    owner = module_globals.get(root_name)
    if owner is None:
        return None

    for segment in path:
        try:
            namespace = owner.__dict__
        except AttributeError:
            return None
        owner = namespace.get(segment)
        if owner is None:
            return None

    if not isinstance(owner, EnumMeta):
        return None
    member = owner.__members__.get(member_name)
    return None if member is None else member.value
3218
+
3219
+
3220
def _build_function_call_expr(
    call: ast.Call,
    func_name: str,
    result: ir.Expr,
    model_converter: Optional[Callable[[ast.Call], Optional[ir.Expr]]] = None,
    enum_resolver: Optional[Callable[[ast.Attribute], Optional[ir.Expr]]] = None,
) -> ir.Expr:
    """Populate ``result`` with a FunctionCall built from ``call`` and return it."""
    args = [
        _expr_to_ir(a, model_converter=model_converter, enum_resolver=enum_resolver)
        for a in call.args
    ]
    kwargs: List[ir.Kwarg] = []
    for kw in call.keywords:
        if kw.arg:
            kw_expr = _expr_to_ir(
                kw.value,
                model_converter=model_converter,
                enum_resolver=enum_resolver,
            )
            if kw_expr:
                kwargs.append(ir.Kwarg(name=kw.arg, value=kw_expr))
    func_call = ir.FunctionCall(
        name=func_name,
        args=[a for a in args if a],
        kwargs=kwargs,
    )
    global_function = _global_function_for_call(func_name, call)
    if global_function is not None:
        func_call.global_function = global_function
    result.function_call.CopyFrom(func_call)
    return result


def _expr_to_ir(
    expr: ast.AST,
    model_converter: Optional[Callable[[ast.Call], Optional[ir.Expr]]] = None,
    enum_resolver: Optional[Callable[[ast.Attribute], Optional[ir.Expr]]] = None,
) -> Optional[ir.Expr]:
    """Convert Python AST expression to IR Expr.

    Handles names, constants, binary/unary/boolean operators, comparisons,
    list/tuple/dict displays, subscripts, attribute access, awaited calls
    and plain calls. Tuples are emitted as lists.

    A chained comparison such as ``a < b < c`` is lowered to
    ``(a < b) and (b < c)``. Unlike Python, the shared operand appears
    (and is therefore evaluated) once per link.

    Args:
        expr: The AST node to convert.
        model_converter: Optional hook tried first on every Call node; a
            truthy result is returned as the conversion.
        enum_resolver: Optional hook tried first on every Attribute node; a
            truthy result is returned as the conversion.

    Returns:
        The IR expression, or None when the node cannot be converted.

    Raises:
        UnsupportedPatternError: For direct calls to functions that are
            neither self-method calls nor in ALLOWED_SYNC_FUNCTIONS, and
            for whatever _check_unsupported_expression rejects.
    """

    def convert(child: ast.AST) -> Optional[ir.Expr]:
        return _expr_to_ir(child, model_converter=model_converter, enum_resolver=enum_resolver)

    result = ir.Expr(span=_make_span(expr))

    if isinstance(expr, ast.Call) and model_converter:
        converted = model_converter(expr)
        if converted:
            return converted

    if isinstance(expr, ast.Name):
        result.variable.CopyFrom(ir.Variable(name=expr.id))
        return result

    if isinstance(expr, ast.Constant):
        literal = _constant_to_literal(expr.value)
        if literal:
            result.literal.CopyFrom(literal)
            return result

    if isinstance(expr, ast.BinOp):
        left = convert(expr.left)
        right = convert(expr.right)
        op = _bin_op_to_ir(expr.op)
        if left and right and op:
            result.binary_op.CopyFrom(ir.BinaryOp(left=left, op=op, right=right))
            return result

    if isinstance(expr, ast.UnaryOp):
        operand = convert(expr.operand)
        op = _unary_op_to_ir(expr.op)
        if operand and op:
            result.unary_op.CopyFrom(ir.UnaryOp(op=op, operand=operand))
            return result

    if isinstance(expr, ast.Compare):
        left = convert(expr.left)
        if not left:
            return None

        # One BinaryOp per link: `a < b < c` -> [a < b, b < c].
        comparisons: List[ir.BinaryOp] = []
        convertible = bool(expr.ops and expr.comparators)
        if convertible:
            previous = left
            for cmp_op, comparator in zip(expr.ops, expr.comparators):
                op = _cmp_op_to_ir(cmp_op)
                right = convert(comparator)
                if not (op and right):
                    convertible = False
                    break
                comparisons.append(ir.BinaryOp(left=previous, op=op, right=right))
                previous = right

        if convertible:
            if len(comparisons) == 1:
                result.binary_op.CopyFrom(comparisons[0])
                return result
            # Chained comparison: AND the links together instead of silently
            # keeping only the first one.
            and_op = _bool_op_to_ir(ast.And())
            if and_op:
                chained: Optional[ir.Expr] = None
                for comparison in comparisons:
                    link = ir.Expr(span=_make_span(expr))
                    link.binary_op.CopyFrom(comparison)
                    if chained is None:
                        chained = link
                    else:
                        conjunction = ir.Expr(span=_make_span(expr))
                        conjunction.binary_op.CopyFrom(
                            ir.BinaryOp(left=chained, op=and_op, right=link)
                        )
                        chained = conjunction
                return chained

    if isinstance(expr, ast.BoolOp):
        values = [convert(v) for v in expr.values]
        if all(values):
            op = _bool_op_to_ir(expr.op)
            if op and len(values) >= 2:
                # Chain boolean ops: a and b and c -> (a and b) and c
                result_expr = values[0]
                for v in values[1:]:
                    new_result = ir.Expr()
                    new_result.binary_op.CopyFrom(ir.BinaryOp(left=result_expr, op=op, right=v))
                    result_expr = new_result
                return result_expr

    if isinstance(expr, (ast.List, ast.Tuple)):
        # Tuples are handled as lists for now
        elements = [convert(e) for e in expr.elts]
        if all(elements):
            result.list.CopyFrom(ir.ListExpr(elements=elements))
            return result

    if isinstance(expr, ast.Dict):
        entries: List[ir.DictEntry] = []
        for k, v in zip(expr.keys, expr.values, strict=False):
            if k:
                key_expr = convert(k)
                value_expr = convert(v)
                if key_expr and value_expr:
                    entries.append(ir.DictEntry(key=key_expr, value=value_expr))
        result.dict.CopyFrom(ir.DictExpr(entries=entries))
        return result

    if isinstance(expr, ast.Subscript):
        obj = convert(expr.value)
        index = convert(expr.slice) if isinstance(expr.slice, ast.AST) else None
        if obj and index:
            result.index.CopyFrom(ir.IndexAccess(object=obj, index=index))
            return result

    if isinstance(expr, ast.Attribute):
        if enum_resolver:
            resolved = enum_resolver(expr)
            if resolved:
                return resolved
        obj = convert(expr.value)
        if obj:
            result.dot.CopyFrom(ir.DotAccess(object=obj, attribute=expr.attr))
            return result

    if isinstance(expr, ast.Await) and isinstance(expr.value, ast.Call):
        func_name = _get_func_name(expr.value.func)
        if func_name:
            return _build_function_call_expr(
                expr.value,
                func_name,
                result,
                model_converter=model_converter,
                enum_resolver=enum_resolver,
            )

    if isinstance(expr, ast.Call):
        # Function call
        if not _is_self_method_call(expr):
            func_name = _get_func_name(expr.func) or "unknown"
            # Method-style calls and names outside the allow-list are rejected.
            if isinstance(expr.func, ast.Attribute) or func_name not in ALLOWED_SYNC_FUNCTIONS:
                raise UnsupportedPatternError(
                    f"Calling synchronous function '{func_name}()' directly is not supported",
                    RECOMMENDATIONS["sync_function_call"],
                    line=getattr(expr, "lineno", None),
                    col=getattr(expr, "col_offset", None),
                )
        func_name = _get_func_name(expr.func)
        if func_name:
            return _build_function_call_expr(
                expr,
                func_name,
                result,
                model_converter=model_converter,
                enum_resolver=enum_resolver,
            )

    # Check for unsupported expression types
    _check_unsupported_expression(expr)

    return None
3469
+
3470
+
3471
def _check_unsupported_expression(expr: ast.AST) -> None:
    """Raise UnsupportedPatternError explaining why ``expr`` cannot be converted.

    The node's ``lineno`` / ``col_offset`` (when present) are forwarded to the
    error. Checks run in this order:

    1. ``ast.Constant`` whose value has no IR literal form.
    2. Specific known-unsupported constructs (f-strings, lambdas,
       comprehensions, generators, walrus, yield), each with its own
       recommendation.
    3. Any remaining ``ast.expr`` node, reported generically by type name.

    A node that is not an ``ast.expr`` at all falls through and the function
    returns ``None`` without raising.
    """
    line = getattr(expr, "lineno", None)
    col = getattr(expr, "col_offset", None)

    # Constants are rejected here only when their value cannot become a
    # literal; other constants continue to the checks below.
    if isinstance(expr, ast.Constant) and _constant_to_literal(expr.value) is None:
        raise UnsupportedPatternError(
            f"Unsupported literal type '{type(expr.value).__name__}'",
            RECOMMENDATIONS["unsupported_literal"],
            line=line,
            col=col,
        )

    # (node type(s), error message, RECOMMENDATIONS key) -- first match wins.
    known_patterns = (
        (ast.JoinedStr, "F-strings are not supported", "fstring"),
        (ast.Lambda, "Lambda expressions are not supported", "lambda"),
        (
            ast.ListComp,
            "List comprehensions are not supported in this context",
            "list_comprehension",
        ),
        (
            ast.DictComp,
            "Dict comprehensions are not supported in this context",
            "dict_comprehension",
        ),
        (ast.SetComp, "Set comprehensions are not supported", "set_comprehension"),
        (ast.GeneratorExp, "Generator expressions are not supported", "generator"),
        (ast.NamedExpr, "The walrus operator (:=) is not supported", "walrus"),
        (
            (ast.Yield, ast.YieldFrom),
            "Yield expressions are not supported",
            "yield_statement",
        ),
    )
    for node_types, message, recommendation_key in known_patterns:
        if isinstance(expr, node_types):
            raise UnsupportedPatternError(
                message,
                RECOMMENDATIONS[recommendation_key],
                line=line,
                col=col,
            )

    # Catch-all for every other expression node type.
    if isinstance(expr, ast.expr):
        raise UnsupportedPatternError(
            f"Unsupported expression type '{type(expr).__name__}'",
            RECOMMENDATIONS["unsupported_expression"],
            line=line,
            col=col,
        )
3548
+
3549
+
3550
def _format_subscript_target(target: ast.Subscript) -> Optional[str]:
    """Render ``name[index]`` for a subscript whose container is a plain name.

    Returns ``None`` when the subscripted value is anything other than an
    ``ast.Name`` or when the index expression cannot be unparsed.
    """
    container = target.value
    if isinstance(container, ast.Name):
        try:
            rendered_index = ast.unparse(target.slice)
        except Exception:
            # Best effort: an index that cannot be rendered is not tracked.
            return None
        return f"{container.id}[{rendered_index}]"
    return None
3562
+
3563
+
3564
def _constant_to_literal(value: Any) -> Optional[ir.Literal]:
    """Build an IR Literal from a Python constant value.

    Supports ``None``, ``bool``, ``int``, ``float`` and ``str``. Any other
    value type yields ``None``.
    """
    literal = ir.Literal()
    if value is None:
        literal.is_none = True
        return literal

    # Order matters: bool is a subclass of int, so it must be tested first.
    scalar_fields = (
        (bool, "bool_value"),
        (int, "int_value"),
        (float, "float_value"),
        (str, "string_value"),
    )
    for py_type, field_name in scalar_fields:
        if isinstance(value, py_type):
            setattr(literal, field_name, value)
            return literal
    return None
3580
+
3581
+
3582
def _bin_op_to_ir(op: ast.operator) -> Optional[ir.BinaryOperator]:
    """Translate a Python arithmetic operator node into an IR BinaryOperator.

    Handles ``+``, ``-``, ``*``, ``/``, ``//`` and ``%``. Every other operator
    node yields ``None``.
    """
    ops = ir.BinaryOperator
    supported = {
        ast.Add: ops.BINARY_OP_ADD,
        ast.Sub: ops.BINARY_OP_SUB,
        ast.Mult: ops.BINARY_OP_MUL,
        ast.Div: ops.BINARY_OP_DIV,
        ast.FloorDiv: ops.BINARY_OP_FLOOR_DIV,
        ast.Mod: ops.BINARY_OP_MOD,
    }
    op_type = type(op)
    return supported[op_type] if op_type in supported else None
3593
+
3594
+
3595
def _unary_op_to_ir(op: ast.unaryop) -> Optional[ir.UnaryOperator]:
    """Translate a Python unary operator node into an IR UnaryOperator.

    Only numeric negation (``-x``) and logical ``not`` are mapped. Every other
    unary operator node yields ``None``.
    """
    op_type = type(op)
    if op_type is ast.USub:
        return ir.UnaryOperator.UNARY_OP_NEG
    if op_type is ast.Not:
        return ir.UnaryOperator.UNARY_OP_NOT
    return None
3602
+
3603
+
3604
def _cmp_op_to_ir(op: ast.cmpop) -> Optional[ir.BinaryOperator]:
    """Translate a Python comparison operator node into an IR BinaryOperator.

    Handles ``==``, ``!=``, ``<``, ``<=``, ``>``, ``>=``, ``in`` and
    ``not in``. Every other comparison node yields ``None``.
    """
    ops = ir.BinaryOperator
    comparisons = (
        (ast.Eq, ops.BINARY_OP_EQ),
        (ast.NotEq, ops.BINARY_OP_NE),
        (ast.Lt, ops.BINARY_OP_LT),
        (ast.LtE, ops.BINARY_OP_LE),
        (ast.Gt, ops.BINARY_OP_GT),
        (ast.GtE, ops.BINARY_OP_GE),
        (ast.In, ops.BINARY_OP_IN),
        (ast.NotIn, ops.BINARY_OP_NOT_IN),
    )
    op_type = type(op)
    for node_type, ir_op in comparisons:
        # Exact type match, mirroring a dict lookup keyed by class.
        if op_type is node_type:
            return ir_op
    return None
3617
+
3618
+
3619
def _bool_op_to_ir(op: ast.boolop) -> Optional[ir.BinaryOperator]:
    """Translate a Python ``and`` / ``or`` node into an IR BinaryOperator.

    Any other node type yields ``None``.
    """
    op_type = type(op)
    if op_type is ast.And:
        return ir.BinaryOperator.BINARY_OP_AND
    if op_type is ast.Or:
        return ir.BinaryOperator.BINARY_OP_OR
    return None
3626
+
3627
+
3628
def _get_func_name(func: ast.expr) -> Optional[str]:
    """Extract the callable's name from the ``func`` node of a Call.

    * A bare name yields that identifier.
    * An attribute chain rooted at a plain name yields the dotted path, with
      a leading ``self.`` removed.
    * Anything else (for example a chain rooted at a call or subscript)
      yields ``None``.
    """
    if isinstance(func, ast.Name):
        return func.id
    if not isinstance(func, ast.Attribute):
        return None

    # Walk from the outermost attribute inward, building the path left to right.
    segments: List[str] = []
    node: ast.expr = func
    while isinstance(node, ast.Attribute):
        segments.insert(0, node.attr)
        node = node.value

    if not isinstance(node, ast.Name):
        return None

    dotted = ".".join([node.id, *segments])
    self_prefix = "self."
    if dotted.startswith(self_prefix):
        return dotted[len(self_prefix):]
    return dotted
3646
+
3647
+
3648
def _is_self_method_call(node: ast.Call) -> bool:
    """Tell whether ``node`` is a direct ``self.method(...)`` call.

    True only when the callee is an attribute access whose receiver is the
    bare name ``self``. Deeper chains such as ``self.a.b()`` do not qualify.
    """
    callee = node.func
    if not isinstance(callee, ast.Attribute):
        return False
    receiver = callee.value
    if not isinstance(receiver, ast.Name):
        return False
    return receiver.id == "self"
3656
+
3657
+
3658
def _global_function_for_call(
    func_name: str, node: ast.Call
) -> Optional[ir.GlobalFunction.ValueType]:
    """Look up the GlobalFunction enum value for a call, if it has one.

    A direct ``self.method(...)`` call is never looked up in
    ``GLOBAL_FUNCTIONS`` and always yields ``None``. Otherwise the result is
    whatever ``GLOBAL_FUNCTIONS`` holds for ``func_name``, or ``None`` when the
    name is absent.
    """
    return None if _is_self_method_call(node) else GLOBAL_FUNCTIONS.get(func_name)