rappel 0.10.0__py3-none-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of rappel has been flagged as potentially problematic; see the package's registry page for details.

rappel/ir_builder.py ADDED
@@ -0,0 +1,3945 @@
1
+ """
2
+ IR Builder - Converts Python workflow AST to Rappel IR (ast.proto).
3
+
4
+ This module parses Python workflow classes and produces the IR representation
5
+ that can be sent to the Rust runtime for execution.
6
+
7
+ The IR builder performs deep transformations to convert Python patterns into
8
+ valid Rappel IR structures. Control flow bodies (try, for, while, if branches)
9
+ are represented as full blocks, so they can contain arbitrary statements and
10
+ `return` behaves consistently with Python (returns from the entry function end
11
+ the workflow).
12
+
13
+ Validation:
14
+ The IR builder proactively detects unsupported Python patterns and raises
15
+ UnsupportedPatternError with clear recommendations for how to rewrite the code.
16
+ """
17
+
18
from __future__ import annotations

import ast
import copy
import inspect
import textwrap
from dataclasses import dataclass, field
from enum import EnumMeta
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Mapping, NoReturn, Optional, Set, Union

from proto import ast_pb2 as ir
from rappel.registry import registry
30
+
31
+
32
class UnsupportedPatternError(Exception):
    """Signals workflow code that the IR builder cannot translate.

    The rendered message is the plain ``message``, followed by a parenthesised
    location listing only the parts that are known (filename if non-empty,
    line if truthy, col if not None), a blank line, and
    ``Recommendation: <recommendation>``.

    Every constructor argument is also kept as an attribute so callers can
    rebuild the error with adjusted coordinates.
    """

    def __init__(
        self,
        message: str,
        recommendation: str,
        line: Optional[int] = None,
        col: Optional[int] = None,
        filename: Optional[str] = None,
    ):
        self.message = message
        self.recommendation = recommendation
        self.line = line
        self.col = col
        self.filename = filename

        # (text, include?) pairs in display order; note a line of 0 is treated as absent.
        candidates = (
            (filename, bool(filename)),
            (f"line {line}", bool(line)),
            (f"col {col}", col is not None),
        )
        where: List[str] = [text for text, include in candidates if include]
        suffix = " (" + ", ".join(where) + ")" if where else ""
        super().__init__(f"{message}{suffix}\n\nRecommendation: {recommendation}")
63
+
64
+
65
# Recommendations for common unsupported patterns, keyed by pattern identifier.
# Each value is user-facing text, presumably passed as the `recommendation`
# argument of UnsupportedPatternError at the raise sites (not visible here).
# NOTE(review): the code samples embedded below appear to have lost their
# indentation in this copy of the file; verify against upstream.
RECOMMENDATIONS = {
    "constructor_return": (
        "Returning a class constructor (like MyModel(...)) directly is not supported.\n"
        "The workflow IR cannot serialize arbitrary object instantiation.\n\n"
        "Use an @action to create the object:\n\n"
        " @action\n"
        " async def build_result(items: list, count: int) -> MyResult:\n"
        " return MyResult(items=items, count=count)\n\n"
        " # In workflow:\n"
        " return await build_result(items, count)"
    ),
    "constructor_assignment": (
        "Assigning a class constructor result (like x = MyClass(...)) is not supported.\n"
        "The workflow IR cannot serialize arbitrary object instantiation.\n\n"
        "Use an @action to create the object:\n\n"
        " @action\n"
        " async def create_config(value: int) -> Config:\n"
        " return Config(value=value)\n\n"
        " # In workflow:\n"
        " config = await create_config(value)"
    ),
    "non_action_call": (
        "Calling a function that is not decorated with @action is not supported.\n"
        "Only @action decorated functions can be awaited in workflow code.\n\n"
        "Add the @action decorator to your function:\n\n"
        " @action\n"
        " async def my_function(x: int) -> int:\n"
        " return x * 2"
    ),
    "sync_function_call": (
        "Calling a synchronous function directly in workflow code is not supported.\n"
        "All computation must happen inside @action decorated async functions.\n\n"
        "Wrap your logic in an @action:\n\n"
        " @action\n"
        " async def compute(x: int) -> int:\n"
        " return some_sync_function(x)"
    ),
    "method_call_non_self": (
        "Calling methods on objects other than 'self' is not supported in workflow code.\n"
        "Use an @action to perform method calls:\n\n"
        " @action\n"
        " async def call_method(obj: MyClass) -> Result:\n"
        " return obj.some_method()"
    ),
    # len() is deliberately not cited as an example: it is one of the
    # GLOBAL_FUNCTIONS / ALLOWED_SYNC_FUNCTIONS that workflow code may call.
    "builtin_call": (
        "Calling built-in functions like str(), int() directly is not supported.\n"
        "Use an @action to perform these operations:\n\n"
        " @action\n"
        " async def to_text(value: int) -> str:\n"
        " return str(value)"
    ),
    "fstring": (
        "F-strings are not supported in workflow code because they require "
        "runtime string interpolation.\n"
        "Use an @action to perform string formatting:\n\n"
        " @action\n"
        " async def format_message(value: int) -> str:\n"
        " return f'Result: {value}'"
    ),
    "delete": (
        "The 'del' statement is not supported in workflow code.\n"
        "Use an @action to perform mutations:\n\n"
        " @action\n"
        " async def remove_key(data: dict, key: str) -> dict:\n"
        " del data[key]\n"
        " return data"
    ),
    # NOTE(review): the module docstring lists `while` bodies among supported
    # control flow; confirm this recommendation is still current.
    "while_loop": (
        "While loops are not supported in workflow code because they can run "
        "indefinitely.\n"
        "Use a for loop with a fixed range, or restructure as recursive workflow calls."
    ),
    "with_statement": (
        "Context managers (with statements) are not supported in workflow code.\n"
        "Use an @action to handle resource management:\n\n"
        " @action\n"
        " async def read_file(path: str) -> str:\n"
        " with open(path) as f:\n"
        " return f.read()"
    ),
    "raise_statement": (
        "The 'raise' statement is not supported directly in workflow code.\n"
        "Use an @action that raises exceptions, or return error values."
    ),
    "assert_statement": (
        "Assert statements are not supported in workflow code.\n"
        "Use an @action for validation, or use if statements with explicit error handling."
    ),
    "lambda": (
        "Lambda expressions are not supported in workflow code.\n"
        "Use an @action to define the function logic."
    ),
    "list_comprehension": (
        "List comprehensions are only supported when assigned directly to a variable "
        "or inside asyncio.gather(*[...]).\n"
        "For other cases, use a for loop or an @action."
    ),
    "dict_comprehension": (
        "Dict comprehensions are only supported when assigned directly to a variable.\n"
        "For other cases, use a for loop or an @action:\n\n"
        " @action\n"
        " async def build_dict(items: list) -> dict:\n"
        " return {k: v for k, v in items}"
    ),
    "set_comprehension": (
        "Set comprehensions are not supported in workflow code.\nUse an @action to build sets."
    ),
    "generator": (
        "Generator expressions are not supported in workflow code.\n"
        "Use a list or an @action instead."
    ),
    "walrus": (
        "The walrus operator (:=) is not supported in workflow code.\n"
        "Use separate assignment statements instead."
    ),
    "match": (
        "Match statements are not supported in workflow code.\nUse if/elif/else chains instead."
    ),
    "gather_return_exceptions": (
        "asyncio.gather() must be called with return_exceptions=True so results and errors "
        "are represented consistently in Rappel's IR.\n\n"
        "Update your call:\n\n"
        " results = await asyncio.gather(task_a(), task_b(), return_exceptions=True)"
    ),
    "gather_variable_spread": (
        "Spreading a variable in asyncio.gather() is not supported because it requires "
        "data flow analysis to determine the contents.\n"
        "Use a list comprehension directly in gather:\n\n"
        " # Instead of:\n"
        " tasks = []\n"
        " for i in range(count):\n"
        " tasks.append(process(value=i))\n"
        " results = await asyncio.gather(*tasks, return_exceptions=True)\n\n"
        " # Use:\n"
        " results = await asyncio.gather(*[process(value=i) for i in range(count)], "
        "return_exceptions=True)"
    ),
    "for_loop_append_pattern": (
        "Building a task list in a for loop then spreading in asyncio.gather() is not "
        "supported.\n"
        "Use a list comprehension directly in gather:\n\n"
        " # Instead of:\n"
        " tasks = []\n"
        " for i in range(count):\n"
        " tasks.append(process(value=i))\n"
        " results = await asyncio.gather(*tasks, return_exceptions=True)\n\n"
        " # Use:\n"
        " results = await asyncio.gather(*[process(value=i) for i in range(count)], "
        "return_exceptions=True)"
    ),
    "global_statement": (
        "Global statements are not supported in workflow code.\n"
        "Use workflow state or pass values explicitly."
    ),
    "nonlocal_statement": (
        "Nonlocal statements are not supported in workflow code.\n"
        "Use explicit parameter passing instead."
    ),
    "import_statement": (
        "Import statements inside workflow run() are not supported.\n"
        "Place imports at the module level."
    ),
    "class_def": (
        "Class definitions inside workflow run() are not supported.\n"
        "Define classes at the module level."
    ),
    "function_def": (
        "Nested function definitions inside workflow run() are not supported.\n"
        "Define functions at the module level or use @action."
    ),
    "yield_statement": (
        "Yield statements are not supported in workflow code.\n"
        "Workflows must return a complete result, not generate values incrementally."
    ),
    "continue_statement": (
        "Continue statements are not supported in workflow code.\n"
        "Restructure your loop using if/else to skip iterations."
    ),
    "unsupported_statement": (
        "This statement type is not supported in workflow code.\n"
        "Move the logic into an @action or rewrite using supported statements."
    ),
    "unsupported_expression": (
        "This expression type is not supported in workflow code.\n"
        "Move the logic into an @action or rewrite using supported expressions."
    ),
    "unsupported_literal": (
        "This literal type is not supported in workflow code.\n"
        "Convert the value to a supported literal type inside an @action."
    ),
}
257
+
258
# Names callable directly from workflow code, mapped to their IR GlobalFunction enum value.
GLOBAL_FUNCTIONS = {
    "enumerate": ir.GlobalFunction.GLOBAL_FUNCTION_ENUMERATE,
    "isexception": ir.GlobalFunction.GLOBAL_FUNCTION_ISEXCEPTION,
    "len": ir.GlobalFunction.GLOBAL_FUNCTION_LEN,
    "range": ir.GlobalFunction.GLOBAL_FUNCTION_RANGE,
}
# Plain-name view of the same table; presumably consulted when validating synchronous
# calls (usage not visible in this chunk).
ALLOWED_SYNC_FUNCTIONS = set(GLOBAL_FUNCTIONS.keys())

# Action names of the module whose function is currently being converted.
# Set per function by build_workflow_ir() and cleared when that build finishes.
_CURRENT_ACTION_NAMES: set[str] = set()
267
+
268
+
269
+ if TYPE_CHECKING:
270
+ from .workflow import Workflow
271
+
272
+
273
@dataclass
class ActionDefinition:
    """An @action function discovered in (or registered for) a module."""

    # Name the action was registered under.
    action_name: str
    # Module the action belongs to (None only if no module could be determined).
    module_name: Optional[str]
    # Signature of the Python function implementing the action.
    signature: inspect.Signature
280
+
281
+
282
@dataclass
class FunctionParameter:
    """One parameter of a workflow function, together with its default (if any)."""

    name: str
    has_default: bool
    # The Python default value; meaningful only when has_default is True.
    default_value: Any = None
289
+
290
+
291
@dataclass
class FunctionSignatureInfo:
    """Parameters of a workflow function (excluding ``self``), used to fill in default arguments."""

    # Parameters in declaration order.
    parameters: List[FunctionParameter]
296
+
297
+
298
@dataclass
class ModuleContext:
    """Per-module discovery results, cached while building one workflow."""

    # @action functions visible in the module, keyed by local name.
    action_defs: Dict[str, ActionDefinition]
    # `from X import Y` names, keyed by the local name.
    imported_names: Dict[str, ImportedName]
    # Coroutine functions defined in the module itself.
    module_functions: Set[str]
    # Simple Pydantic models / dataclasses visible in the module.
    model_defs: Dict[str, ModelDefinition]
306
+
307
+
308
@dataclass
class TransformContext:
    """Mutable state shared by every IRBuilder during one build_workflow_ir() call."""

    # Counter backing next_implicit_fn_name(); incremented before use, so names start at 1.
    implicit_fn_counter: int = 0
    # Implicit functions generated during transformation.
    implicit_functions: List[ir.FunctionDef] = field(default_factory=list)
    # Signatures of workflow methods keyed by IR function name (used to fill in default arguments).
    function_signatures: Dict[str, FunctionSignatureInfo] = field(default_factory=dict)

    def __post_init__(self) -> None:
        # Backward compatibility: callers that pass None explicitly still get fresh containers.
        if self.implicit_functions is None:
            self.implicit_functions = []
        if self.function_signatures is None:
            self.function_signatures = {}

    def next_implicit_fn_name(self, prefix: str = "implicit") -> str:
        """Return a fresh name ``__<prefix>_<n>__``; ``n`` is shared across all prefixes."""
        self.implicit_fn_counter += 1
        return f"__{prefix}_{self.implicit_fn_counter}__"
329
+
330
+
331
def build_workflow_ir(workflow_cls: type["Workflow"]) -> ir.Program:
    """Build an IR Program from a workflow class.

    The run implementation is emitted under the name ``"main"``. It is
    ``__workflow_run_impl__`` if set, otherwise ``run`` from the class
    ``__dict__``. Every ``self.<method>(...)`` call reachable from it is
    followed transitively. Each method resolvable via the MRO is emitted as an
    additional function; ``run`` and ``run_action`` are never followed.

    Args:
        workflow_cls: The workflow class to convert.

    Returns:
        An IR Program proto message. Implicit functions generated during the
        transformation come first, then ``main``, then helper methods in a
        deterministic discovery order (depth-first, alphabetical per batch).

    Raises:
        ValueError: If the run implementation cannot be found, or if the
            module of the run implementation or of a helper cannot be located.
        UnsupportedPatternError: If a function uses an unsupported pattern.
            Its line/col are re-based onto the function's source file.
        OSError, TypeError: Propagated from ``inspect`` when a function's
            source cannot be retrieved.
    """
    global _CURRENT_ACTION_NAMES

    original_run = getattr(workflow_cls, "__workflow_run_impl__", None)
    if original_run is None:
        original_run = workflow_cls.__dict__.get("run")
    if original_run is None:
        raise ValueError(f"workflow {workflow_cls!r} missing run implementation")

    if inspect.getmodule(original_run) is None:
        raise ValueError(f"unable to locate module for workflow {workflow_cls!r}")

    # Discovery results cached per module name. Helper methods may be defined in
    # other modules (e.g. on a base class), so each module gets its own context.
    module_contexts: Dict[str, ModuleContext] = {}

    def get_module_context(target_module: Any) -> ModuleContext:
        """Return (building and caching on first use) the discovery context for a module."""
        module_name = target_module.__name__
        if module_name not in module_contexts:
            module_contexts[module_name] = ModuleContext(
                action_defs=_discover_action_names(target_module),
                imported_names=_discover_module_imports(target_module),
                module_functions=_discover_module_functions(target_module),
                model_defs=_discover_model_definitions(target_module),
            )
        return module_contexts[module_name]

    # One transformation context is shared by every function of this workflow.
    ctx = TransformContext()
    program = ir.Program()
    function_defs: Dict[str, ir.FunctionDef] = {}

    # Extract instance attributes from __init__ for policy resolution.
    instance_attrs = _extract_instance_attrs(workflow_cls)

    def parse_function(fn: Any) -> tuple[ast.AST, Optional[str], int, int]:
        """Parse ``fn``'s source.

        Returns ``(tree, filename, start_line, col_shift)``:
        - ``start_line`` is the 1-based file line of the first source line.
        - ``col_shift`` is the number of leading characters ``textwrap.dedent``
          removed from every non-blank line; it is needed to map columns back
          to the file.
        """
        source_lines, start_line = inspect.getsourcelines(fn)
        function_source = textwrap.dedent("".join(source_lines))
        raw_first = next((line for line in source_lines if line.strip()), "")
        dedented_first = next(
            (line for line in function_source.splitlines(keepends=True) if line.strip()), ""
        )
        col_shift = len(raw_first) - len(dedented_first)
        filename = inspect.getsourcefile(fn)
        if filename is None:
            filename = inspect.getfile(fn)
        tree = ast.parse(function_source, filename=filename or "<unknown>")
        return tree, filename, start_line, col_shift

    def _with_source_location(
        err: UnsupportedPatternError,
        filename: Optional[str],
        start_line: int,
        col_shift: int,
    ) -> UnsupportedPatternError:
        """Rebuild ``err`` with positions relative to the source file instead of the snippet."""
        line = err.line
        col = err.col
        if line is not None:
            # err.line is 1-based within the parsed snippet.
            line = start_line + line - 1
        if col is not None:
            # Assumes err.col is a 0-based offset into the dedented snippet (TODO confirm
            # at the raise sites). Add back the stripped indentation and make it 1-based.
            col = col + col_shift + 1
        return UnsupportedPatternError(
            err.message,
            err.recommendation,
            line=line,
            col=col,
            filename=filename,
        )

    def add_function_def(
        fn: Any,
        fn_tree: ast.AST,
        filename: Optional[str],
        start_line: int,
        col_shift: int,
        override_name: Optional[str] = None,
    ) -> None:
        """Convert one parsed function to IR and record it in ``function_defs``."""
        global _CURRENT_ACTION_NAMES
        fn_module = inspect.getmodule(fn)
        if fn_module is None:
            raise ValueError(f"unable to locate module for function {fn!r}")

        ctx_data = get_module_context(fn_module)
        # Expose the action names of this function's own module while it is converted.
        _CURRENT_ACTION_NAMES = set(ctx_data.action_defs.keys())
        builder = IRBuilder(
            ctx_data.action_defs,
            ctx,
            ctx_data.imported_names,
            ctx_data.module_functions,
            ctx_data.model_defs,
            fn_module.__dict__,
            instance_attrs,
        )
        try:
            builder.visit(fn_tree)
        except UnsupportedPatternError as err:
            raise _with_source_location(err, filename, start_line, col_shift) from err
        if builder.function_def:
            if override_name:
                builder.function_def.name = override_name
            function_defs[builder.function_def.name] = builder.function_def

    # Discover all reachable helper methods first, so their signatures can be
    # pre-collected. When a function calls another function, the callee's
    # signature is needed to fill in default arguments.
    run_tree, run_filename, run_start_line, run_col_shift = parse_function(original_run)

    methods_to_process: List[tuple[Any, ast.AST, Optional[str], int, int, Optional[str]]] = [
        (original_run, run_tree, run_filename, run_start_line, run_col_shift, "main")
    ]
    # _collect_self_method_calls returns a set of strings, and its iteration order
    # depends on PYTHONHASHSEED. Sort every batch so the discovery order, and hence
    # the order of program.functions, is reproducible. Reverse order because pop()
    # takes from the end of the list.
    pending: List[str] = sorted(_collect_self_method_calls(run_tree), reverse=True)
    visited: Set[str] = set()
    skip_methods = {"run", "run_action"}

    while pending:
        method_name = pending.pop()
        if method_name in visited or method_name in skip_methods:
            continue
        visited.add(method_name)

        method = _find_workflow_method(workflow_cls, method_name)
        if method is None:
            continue

        method_tree, method_filename, method_start_line, method_col_shift = parse_function(method)
        methods_to_process.append(
            (method, method_tree, method_filename, method_start_line, method_col_shift, None)
        )
        pending.extend(sorted(_collect_self_method_calls(method_tree), reverse=True))

    # Pre-collect signatures for all methods before building any IR.
    for fn, _fn_tree, _filename, _start_line, _col_shift, override_name in methods_to_process:
        params: List[FunctionParameter] = []
        for param_name, param in inspect.signature(fn).parameters.items():
            if param_name == "self":
                continue
            has_default = param.default is not inspect.Parameter.empty
            params.append(
                FunctionParameter(
                    name=param_name,
                    has_default=has_default,
                    default_value=param.default if has_default else None,
                )
            )
        ctx.function_signatures[override_name or fn.__name__] = FunctionSignatureInfo(
            parameters=params
        )

    try:
        for fn, fn_tree, filename, start_line, col_shift, override_name in methods_to_process:
            add_function_def(fn, fn_tree, filename, start_line, col_shift, override_name)
    finally:
        # Always reset, even when conversion fails (e.g. UnsupportedPatternError),
        # so a later build never observes stale action names.
        _CURRENT_ACTION_NAMES = set()

    # Implicit functions first (they may be called by the main function).
    for implicit_fn in ctx.implicit_functions:
        program.functions.append(implicit_fn)

    # Then all function definitions (run + reachable helper methods).
    for fn_def in function_defs.values():
        program.functions.append(fn_def)

    return program
493
+
494
+
495
+ def _collect_self_method_calls(tree: ast.AST) -> Set[str]:
496
+ """Collect self.method(...) call names from a parsed function AST."""
497
+ calls: Set[str] = set()
498
+ for node in ast.walk(tree):
499
+ if not isinstance(node, ast.Call):
500
+ continue
501
+ func = node.func
502
+ if isinstance(func, ast.Attribute) and isinstance(func.value, ast.Name):
503
+ if func.value.id == "self":
504
+ calls.add(func.attr)
505
+ return calls
506
+
507
+
508
+ def _find_workflow_method(workflow_cls: type["Workflow"], name: str) -> Optional[Any]:
509
+ """Find a workflow method by name across the class MRO."""
510
+ for base in workflow_cls.__mro__:
511
+ if name not in base.__dict__:
512
+ continue
513
+ value = base.__dict__[name]
514
+ if isinstance(value, staticmethod) or isinstance(value, classmethod):
515
+ return value.__func__
516
+ if inspect.isfunction(value):
517
+ return value
518
+ return None
519
+
520
+
521
def _extract_instance_attrs(workflow_cls: type["Workflow"]) -> Dict[str, ast.expr]:
    """Extract ``self.attr = value`` assignments from the workflow's ``__init__``.

    Recognised forms, found anywhere in the ``__init__`` source (including
    nested blocks)::

        self.retry_policy = RetryPolicy(attempts=3)
        self.timeout: int = 30

    Multi-target assignments (``self.a = self.b = 1``) and bare annotations
    without a value (``self.x: int``) are ignored.

    Returns:
        Mapping of attribute name to the AST node of the assigned value. If an
        attribute is assigned several times, the node visited last by
        ``ast.walk`` (unspecified order) wins. The mapping is empty when no
        ``__init__`` function is found, or when its source cannot be retrieved
        or parsed.
    """
    init_method = _find_workflow_method(workflow_cls, "__init__")
    if init_method is None:
        return {}

    try:
        source_lines, _ = inspect.getsourcelines(init_method)
        tree = ast.parse(textwrap.dedent("".join(source_lines)))
    except (OSError, TypeError, SyntaxError):
        # Source unavailable, or not parseable on its own.
        return {}

    attrs: Dict[str, ast.expr] = {}
    for node in ast.walk(tree):
        if isinstance(node, ast.Assign):
            # Only handle single-target assignments.
            if len(node.targets) != 1:
                continue
            target = node.targets[0]
            value = node.value
        elif isinstance(node, ast.AnnAssign) and node.value is not None:
            # `self.x: T = value` assigns exactly like `self.x = value`.
            target = node.target
            value = node.value
        else:
            continue

        if (
            isinstance(target, ast.Attribute)
            and isinstance(target.value, ast.Name)
            and target.value.id == "self"
        ):
            attrs[target.attr] = value

    return attrs
560
+
561
+
562
def _discover_action_names(module: Any) -> Dict[str, ActionDefinition]:
    """Discover all @action decorated functions in a module.

    Two sources are merged, keyed by local name:

    1. Callable module attributes with a truthy ``__rappel_action_name__``.
       Their module is ``__rappel_action_module__`` when set, else
       ``module.__name__``.
    2. Entries in the action ``registry`` whose ``module`` equals
       ``module.__name__``, unless their function name was already found in
       step 1.

    Args:
        module: The module to scan.

    Returns:
        Mapping of local name to ActionDefinition.
    """
    names: Dict[str, ActionDefinition] = {}
    for attr_name in dir(module):
        try:
            attr = getattr(module, attr_name)
        except AttributeError:
            # dir() can list names whose lookup still fails (e.g. lazy module
            # attributes); skip them like the other module scanners do.
            continue
        action_name = getattr(attr, "__rappel_action_name__", None)
        if not (callable(attr) and action_name):
            continue
        action_module = getattr(attr, "__rappel_action_module__", None)
        names[attr_name] = ActionDefinition(
            action_name=action_name,
            module_name=action_module or module.__name__,
            signature=inspect.signature(attr),
        )

    for entry in registry.entries():
        if entry.module != module.__name__:
            continue
        func_name = entry.func.__name__
        if func_name in names:
            continue
        names[func_name] = ActionDefinition(
            action_name=entry.name,
            module_name=entry.module,
            signature=inspect.signature(entry.func),
        )
    return names
589
+
590
+
591
+ def _discover_module_functions(module: Any) -> Set[str]:
592
+ """Discover all async function names defined in a module.
593
+
594
+ This is used to detect when users await functions in the same module
595
+ that are NOT decorated with @action.
596
+ """
597
+ function_names: Set[str] = set()
598
+ for attr_name in dir(module):
599
+ try:
600
+ attr = getattr(module, attr_name)
601
+ except AttributeError:
602
+ continue
603
+
604
+ # Only include functions defined in THIS module (not imported)
605
+ if not callable(attr):
606
+ continue
607
+ if not inspect.iscoroutinefunction(attr):
608
+ continue
609
+
610
+ # Check if the function is defined in this module
611
+ func_module = getattr(attr, "__module__", None)
612
+ if func_module == module.__name__:
613
+ function_names.add(attr_name)
614
+
615
+ return function_names
616
+
617
+
618
@dataclass
class ImportedName:
    """A name bound by ``from <module> import <original_name> [as <local_name>]``."""

    # Name used in code (e.g. "sleep", or the alias after `as`).
    local_name: str
    # Source module (e.g. "asyncio").
    module: str
    # Name as defined in the source module (e.g. "sleep").
    original_name: str
625
+
626
+
627
@dataclass
class ModelFieldDefinition:
    """A single field of a Pydantic model or dataclass."""

    name: str
    has_default: bool
    # Static default value; only meaningful when has_default is True.
    default_value: Any = None
634
+
635
+
636
@dataclass
class ModelDefinition:
    """A Pydantic model or dataclass that workflow code may instantiate.

    Instantiations of these classes are converted to dictionary expressions in
    the IR.
    """

    class_name: str
    # Field definitions keyed by field name.
    fields: Dict[str, ModelFieldDefinition]
    # True for Pydantic models, False for dataclasses.
    is_pydantic: bool
647
+
648
+
649
+ def _is_simple_pydantic_model(cls: type) -> bool:
650
+ """Check if a class is a simple Pydantic model without custom validators.
651
+
652
+ A simple Pydantic model:
653
+ - Inherits from pydantic.BaseModel
654
+ - Has no field_validator or model_validator decorators
655
+ - Has no custom __init__ method
656
+
657
+ Returns False if pydantic is not installed or cls is not a Pydantic model.
658
+ """
659
+ try:
660
+ from pydantic import BaseModel
661
+ except ImportError:
662
+ return False
663
+
664
+ if not isinstance(cls, type) or not issubclass(cls, BaseModel):
665
+ return False
666
+
667
+ # Check for validators - Pydantic v2 uses __pydantic_decorators__
668
+ decorators = getattr(cls, "__pydantic_decorators__", None)
669
+ if decorators is not None:
670
+ # Check for field validators
671
+ if hasattr(decorators, "field_validators") and decorators.field_validators:
672
+ return False
673
+ # Check for model validators
674
+ if hasattr(decorators, "model_validators") and decorators.model_validators:
675
+ return False
676
+
677
+ # Check for custom __init__ (not the one from BaseModel)
678
+ if "__init__" in cls.__dict__:
679
+ return False
680
+
681
+ return True
682
+
683
+
684
+ def _is_simple_dataclass(cls: type) -> bool:
685
+ """Check if a class is a simple dataclass without custom logic.
686
+
687
+ A simple dataclass:
688
+ - Is decorated with @dataclass
689
+ - Has no custom __init__ method (uses the generated one)
690
+ - Has no __post_init__ method
691
+
692
+ Returns False if cls is not a dataclass.
693
+ """
694
+ import dataclasses
695
+
696
+ if not dataclasses.is_dataclass(cls):
697
+ return False
698
+
699
+ # Check for __post_init__ which could have custom logic
700
+ if hasattr(cls, "__post_init__") and "__post_init__" in cls.__dict__:
701
+ return False
702
+
703
+ # Dataclasses generate __init__, so check if there's a custom one
704
+ # that overrides it (unlikely but possible)
705
+ # The dataclass decorator sets __init__ so we can't easily detect override
706
+ # We'll trust that dataclasses without __post_init__ are simple
707
+
708
+ return True
709
+
710
+
711
+ def _get_pydantic_model_fields(cls: type) -> Dict[str, ModelFieldDefinition]:
712
+ """Extract field definitions from a Pydantic model."""
713
+ try:
714
+ from pydantic import BaseModel
715
+ from pydantic.fields import FieldInfo
716
+ except ImportError:
717
+ return {}
718
+
719
+ if not issubclass(cls, BaseModel):
720
+ return {}
721
+
722
+ fields: Dict[str, ModelFieldDefinition] = {}
723
+
724
+ # Pydantic v2 uses model_fields
725
+ model_fields = getattr(cls, "model_fields", {})
726
+ for field_name, field_info in model_fields.items():
727
+ has_default = False
728
+ default_value = None
729
+
730
+ if isinstance(field_info, FieldInfo):
731
+ # Check if field has a default value
732
+ # PydanticUndefined means no default
733
+ from pydantic_core import PydanticUndefined
734
+
735
+ if field_info.default is not PydanticUndefined:
736
+ has_default = True
737
+ default_value = field_info.default
738
+ elif field_info.default_factory is not None:
739
+ # We can't serialize factory functions, so treat as no default
740
+ has_default = False
741
+
742
+ fields[field_name] = ModelFieldDefinition(
743
+ name=field_name,
744
+ has_default=has_default,
745
+ default_value=default_value,
746
+ )
747
+
748
+ return fields
749
+
750
+
751
+ def _get_dataclass_fields(cls: type) -> Dict[str, ModelFieldDefinition]:
752
+ """Extract field definitions from a dataclass."""
753
+ import dataclasses
754
+
755
+ if not dataclasses.is_dataclass(cls):
756
+ return {}
757
+
758
+ fields: Dict[str, ModelFieldDefinition] = {}
759
+ for field in dataclasses.fields(cls):
760
+ has_default = False
761
+ default_value = None
762
+
763
+ if field.default is not dataclasses.MISSING:
764
+ has_default = True
765
+ default_value = field.default
766
+ elif field.default_factory is not dataclasses.MISSING:
767
+ # We can't serialize factory functions, so treat as no default
768
+ has_default = False
769
+
770
+ fields[field.name] = ModelFieldDefinition(
771
+ name=field.name,
772
+ has_default=has_default,
773
+ default_value=default_value,
774
+ )
775
+
776
+ return fields
777
+
778
+
779
+ def _discover_model_definitions(module: Any) -> Dict[str, ModelDefinition]:
780
+ """Discover all Pydantic models and dataclasses that can be used in workflows.
781
+
782
+ Only discovers "simple" models without custom validators or __post_init__.
783
+ """
784
+ models: Dict[str, ModelDefinition] = {}
785
+
786
+ for attr_name in dir(module):
787
+ try:
788
+ attr = getattr(module, attr_name)
789
+ except AttributeError:
790
+ continue
791
+
792
+ if not isinstance(attr, type):
793
+ continue
794
+
795
+ # Check if this class is defined in this module or imported
796
+ # We want to include both, as models might be imported
797
+ if _is_simple_pydantic_model(attr):
798
+ fields = _get_pydantic_model_fields(attr)
799
+ models[attr_name] = ModelDefinition(
800
+ class_name=attr_name,
801
+ fields=fields,
802
+ is_pydantic=True,
803
+ )
804
+ elif _is_simple_dataclass(attr):
805
+ fields = _get_dataclass_fields(attr)
806
+ models[attr_name] = ModelDefinition(
807
+ class_name=attr_name,
808
+ fields=fields,
809
+ is_pydantic=False,
810
+ )
811
+
812
+ return models
813
+
814
+
815
+ def _discover_module_imports(module: Any) -> Dict[str, ImportedName]:
816
+ """Discover imports in a module by parsing its source.
817
+
818
+ Tracks imports like:
819
+ - from asyncio import sleep -> {"sleep": ImportedName("sleep", "asyncio", "sleep")}
820
+ - from asyncio import sleep as s -> {"s": ImportedName("s", "asyncio", "sleep")}
821
+ """
822
+ imported: Dict[str, ImportedName] = {}
823
+
824
+ try:
825
+ source = inspect.getsource(module)
826
+ tree = ast.parse(source)
827
+ except (OSError, TypeError):
828
+ # Can't get source (e.g., built-in module)
829
+ return imported
830
+
831
+ for node in ast.walk(tree):
832
+ if isinstance(node, ast.ImportFrom) and node.module:
833
+ for alias in node.names:
834
+ local_name = alias.asname if alias.asname else alias.name
835
+ imported[local_name] = ImportedName(
836
+ local_name=local_name,
837
+ module=node.module,
838
+ original_name=alias.name,
839
+ )
840
+
841
+ return imported
842
+
843
+
844
+ class IRBuilder(ast.NodeVisitor):
845
+ """Builds IR from Python AST with deep transformations."""
846
+
847
+ def __init__(
848
+ self,
849
+ action_defs: Dict[str, ActionDefinition],
850
+ ctx: TransformContext,
851
+ imported_names: Optional[Dict[str, ImportedName]] = None,
852
+ module_functions: Optional[Set[str]] = None,
853
+ model_defs: Optional[Dict[str, ModelDefinition]] = None,
854
+ module_globals: Optional[Mapping[str, Any]] = None,
855
+ instance_attrs: Optional[Dict[str, ast.expr]] = None,
856
+ ):
857
+ self._action_defs = action_defs
858
+ self._ctx = ctx
859
+ self._imported_names = imported_names or {}
860
+ self._module_functions = module_functions or set()
861
+ self._model_defs = model_defs or {}
862
+ self._module_globals = module_globals or {}
863
+ self._instance_attrs = instance_attrs or {}
864
+ self.function_def: Optional[ir.FunctionDef] = None
865
+ self._statements: List[ir.Statement] = []
866
+
867
def visit_FunctionDef(self, node: ast.FunctionDef) -> Any:
    """Visit a function definition (the workflow's run method)."""
    self._build_workflow_function_def(node)

def visit_AsyncFunctionDef(self, node: ast.AsyncFunctionDef) -> Any:
    """Visit an async function definition (the workflow's run method).

    ``async`` has no IR-level meaning, so this is handled exactly like a
    synchronous definition.
    """
    self._build_workflow_function_def(node)

def _build_workflow_function_def(
    self, node: Union[ast.FunctionDef, ast.AsyncFunctionDef]
) -> None:
    """Populate ``self.function_def`` (and ``self._statements``) from *node*.

    Shared by the sync and async visitors so the two cannot drift apart.
    The IR function has the node's name and collected inputs, no declared
    outputs, and a body made of the IR translation of every statement in
    ``node.body``, in order.
    """
    inputs = self._collect_function_inputs(node)

    self.function_def = ir.FunctionDef(
        name=node.name,
        io=ir.IoDecl(inputs=inputs, outputs=[]),
        span=_make_span(node),
    )

    # Reset per visit: _visit_statement may expand one Python statement
    # into several IR statements, so results are extended, not appended.
    self._statements = []
    for stmt in node.body:
        self._statements.extend(self._visit_statement(stmt))

    self.function_def.body.CopyFrom(ir.Block(statements=self._statements))
904
+
905
+ def _visit_statement(self, node: ast.stmt) -> List[ir.Statement]:
906
+ """Convert a Python statement to IR Statement(s).
907
+
908
+ Returns a list because some transformations (like try block hoisting)
909
+ may expand a single Python statement into multiple IR statements.
910
+
911
+ Raises UnsupportedPatternError for unsupported statement types.
912
+ """
913
+ if isinstance(node, ast.Assign):
914
+ dict_expanded = self._expand_dict_comprehension_assignment(node)
915
+ if dict_expanded is not None:
916
+ return dict_expanded
917
+ expanded = self._expand_list_comprehension_assignment(node)
918
+ if expanded is not None:
919
+ return expanded
920
+ result = self._visit_assign(node)
921
+ return [result] if result else []
922
+ elif isinstance(node, ast.AnnAssign):
923
+ result = self._visit_ann_assign(node)
924
+ return [result] if result else []
925
+ elif isinstance(node, ast.Expr):
926
+ result = self._visit_expr_stmt(node)
927
+ return [result] if result else []
928
+ elif isinstance(node, ast.For):
929
+ return self._visit_for(node)
930
+ elif isinstance(node, ast.While):
931
+ return self._visit_while(node)
932
+ elif isinstance(node, ast.If):
933
+ return self._visit_if(node)
934
+ elif isinstance(node, ast.Try):
935
+ return self._visit_try(node)
936
+ elif isinstance(node, ast.Return):
937
+ return self._visit_return(node)
938
+ elif isinstance(node, ast.AugAssign):
939
+ return self._visit_aug_assign(node)
940
+ elif isinstance(node, ast.Pass):
941
+ # Pass statements are fine, they just don't produce IR
942
+ return []
943
+ elif isinstance(node, ast.Break):
944
+ return self._visit_break(node)
945
+ elif isinstance(node, ast.Continue):
946
+ return self._visit_continue(node)
947
+
948
+ # Check for unsupported statement types - this MUST raise for any
949
+ # unhandled statement to avoid silently dropping code
950
+ self._check_unsupported_statement(node)
951
+
952
def _check_unsupported_statement(self, node: ast.stmt) -> NoReturn:
    """Raise a descriptive UnsupportedPatternError for *node*; never returns.

    Known-unsupported statement kinds are matched against a rule table that
    pairs each kind with its message and RECOMMENDATIONS key. Any other
    statement type gets a generic "Unhandled statement type" error, so a
    statement that reaches this point can never be silently dropped.
    """
    line = getattr(node, "lineno", None)
    col = getattr(node, "col_offset", None)

    # (statement type(s), error message, RECOMMENDATIONS key)
    rules = [
        (
            (ast.With, ast.AsyncWith),
            "Context managers (with statements) are not supported",
            "with_statement",
        ),
        (ast.Raise, "The 'raise' statement is not supported", "raise_statement"),
        (ast.Assert, "Assert statements are not supported", "assert_statement"),
        (ast.Delete, "The 'del' statement is not supported", "delete"),
        (ast.Global, "Global statements are not supported", "global_statement"),
        (ast.Nonlocal, "Nonlocal statements are not supported", "nonlocal_statement"),
        (
            (ast.Import, ast.ImportFrom),
            "Import statements inside workflow run() are not supported",
            "import_statement",
        ),
        (
            ast.ClassDef,
            "Class definitions inside workflow run() are not supported",
            "class_def",
        ),
        (
            (ast.FunctionDef, ast.AsyncFunctionDef),
            "Nested function definitions are not supported",
            "function_def",
        ),
    ]
    # ast.Match only exists on interpreters with structural pattern matching.
    match_type = getattr(ast, "Match", None)
    if match_type is not None:
        rules.append((match_type, "Match statements are not supported", "match"))

    for node_types, message, recommendation_key in rules:
        if isinstance(node, node_types):
            raise UnsupportedPatternError(
                message,
                RECOMMENDATIONS[recommendation_key],
                line=line,
                col=col,
            )

    # Catch-all: critical to avoid silently dropping code.
    raise UnsupportedPatternError(
        f"Unhandled statement type: {type(node).__name__}",
        RECOMMENDATIONS["unsupported_statement"],
        line=line,
        col=col,
    )
1043
+
1044
def _expand_list_comprehension_assignment(
    self, node: ast.Assign
) -> Optional[List[ir.Statement]]:
    """Expand a list comprehension assignment into loop-based statements.

    Example:
        active_users = [user for user in users if user.active]

    Becomes:
        active_users = []
        for user in users:
            if user.active:
                active_users = active_users + [user]

    Python evaluates the whole comprehension before rebinding the target, so
    any mention of the target name inside the comprehension (its iterable,
    element, filters, or loop variable) refers to the *previous* value.
    Accumulating straight into the target would break that: for
    ``items = [f(x) for x in items]`` the loop would iterate the freshly
    reset ``[]``. When the target is mentioned, the list is therefore built
    in a temporary variable (named by the transform context) and assigned to
    the target after the loop.

    Returns:
        The expanded IR statements, or None if ``node`` does not assign a
        list comprehension.

    Raises:
        UnsupportedPatternError: if the target is not a single name, or the
            comprehension has multiple generators or is async.
    """
    if not isinstance(node.value, ast.ListComp):
        return None

    if len(node.targets) != 1 or not isinstance(node.targets[0], ast.Name):
        line = getattr(node, "lineno", None)
        col = getattr(node, "col_offset", None)
        raise UnsupportedPatternError(
            "List comprehension assignments must target a single variable",
            "Assign the comprehension to a simple variable like `results = [x for x in items]`",
            line=line,
            col=col,
        )

    listcomp = node.value
    if len(listcomp.generators) != 1:
        line = getattr(listcomp, "lineno", None)
        col = getattr(listcomp, "col_offset", None)
        raise UnsupportedPatternError(
            "List comprehensions with multiple generators are not supported",
            "Use nested for loops instead of combining multiple generators in one comprehension",
            line=line,
            col=col,
        )

    gen = listcomp.generators[0]
    if gen.is_async:
        line = getattr(listcomp, "lineno", None)
        col = getattr(listcomp, "col_offset", None)
        raise UnsupportedPatternError(
            "Async list comprehensions are not supported",
            "Rewrite using an explicit async for loop",
            line=line,
            col=col,
        )

    target_name = node.targets[0].id

    # If the comprehension reads (or loops over) the target name, build the
    # list in a temporary so those reads still see the pre-assignment value.
    mentions_target = any(
        isinstance(child, ast.Name) and child.id == target_name
        for child in ast.walk(listcomp)
    )
    acc_name = (
        self._ctx.next_implicit_fn_name(prefix="listcomp_acc")
        if mentions_target
        else target_name
    )

    # Initialize the accumulator list: acc = []
    init_assign_ast = ast.Assign(
        targets=[ast.Name(id=acc_name, ctx=ast.Store())],
        value=ast.List(elts=[], ctx=ast.Load()),
        type_comment=None,
    )
    ast.copy_location(init_assign_ast, node)
    ast.fix_missing_locations(init_assign_ast)

    def _make_append_assignment(value_expr: ast.expr) -> ast.Assign:
        # acc = acc + [value]: explicit rebinding instead of .append() so the
        # data flow stays visible.
        append_assign = ast.Assign(
            targets=[ast.Name(id=acc_name, ctx=ast.Store())],
            value=ast.BinOp(
                left=ast.Name(id=acc_name, ctx=ast.Load()),
                op=ast.Add(),
                right=ast.List(elts=[copy.deepcopy(value_expr)], ctx=ast.Load()),
            ),
            type_comment=None,
        )
        ast.copy_location(append_assign, node.value)
        ast.fix_missing_locations(append_assign)
        return append_assign

    append_statements: List[ast.stmt] = []
    if isinstance(listcomp.elt, ast.IfExp):
        # `[a if c else b for ...]` -> if c: acc += [a] else: acc += [b]
        then_assign = _make_append_assignment(listcomp.elt.body)
        else_assign = _make_append_assignment(listcomp.elt.orelse)
        branch_if = ast.If(
            test=copy.deepcopy(listcomp.elt.test),
            body=[then_assign],
            orelse=[else_assign],
        )
        ast.copy_location(branch_if, listcomp.elt)
        ast.fix_missing_locations(branch_if)
        append_statements.append(branch_if)
    else:
        append_statements.append(_make_append_assignment(listcomp.elt))

    loop_body: List[ast.stmt] = append_statements
    if gen.ifs:
        # Multiple `if` clauses are combined with `and`, matching Python.
        condition: ast.expr
        if len(gen.ifs) == 1:
            condition = copy.deepcopy(gen.ifs[0])
        else:
            condition = ast.BoolOp(op=ast.And(), values=[copy.deepcopy(iff) for iff in gen.ifs])
        ast.copy_location(condition, gen.ifs[0])
        if_stmt = ast.If(test=condition, body=append_statements, orelse=[])
        ast.copy_location(if_stmt, gen.ifs[0])
        ast.fix_missing_locations(if_stmt)
        loop_body = [if_stmt]

    loop_ast = ast.For(
        target=copy.deepcopy(gen.target),
        iter=copy.deepcopy(gen.iter),
        body=loop_body,
        orelse=[],
        type_comment=None,
    )
    ast.copy_location(loop_ast, node)
    ast.fix_missing_locations(loop_ast)

    statements: List[ir.Statement] = []
    init_stmt = self._visit_assign(init_assign_ast)
    if init_stmt:
        statements.append(init_stmt)
    statements.extend(self._visit_for(loop_ast))

    if acc_name != target_name:
        # target = acc, only once the comprehension is fully evaluated.
        final_assign_ast = ast.Assign(
            targets=[ast.Name(id=target_name, ctx=ast.Store())],
            value=ast.Name(id=acc_name, ctx=ast.Load()),
            type_comment=None,
        )
        ast.copy_location(final_assign_ast, node)
        ast.fix_missing_locations(final_assign_ast)
        final_stmt = self._visit_assign(final_assign_ast)
        if final_stmt:
            statements.append(final_stmt)

    return statements
1163
+
1164
def _expand_dict_comprehension_assignment(
    self, node: ast.Assign
) -> Optional[List[ir.Statement]]:
    """Expand a dict comprehension assignment into loop-based statements.

    Example:
        result = {k: v for k, v in pairs if v}

    Becomes:
        result = {}
        for k, v in pairs:
            if v:
                result[k] = v

    Python evaluates the whole comprehension before rebinding the target, so
    any mention of the target name inside the comprehension (iterable, key,
    value, filters, or loop variable) refers to the *previous* value. For
    ``d = {k: v for k, v in d.items()}`` accumulating straight into ``d``
    would iterate the freshly reset ``{}``. When the target is mentioned, the
    dict is therefore built in a temporary variable (named by the transform
    context) and assigned to the target after the loop.

    Returns:
        The expanded IR statements, or None if ``node`` does not assign a
        dict comprehension.

    Raises:
        UnsupportedPatternError: if the target is not a single name, or the
            comprehension has multiple generators or is async.
    """
    if not isinstance(node.value, ast.DictComp):
        return None

    if len(node.targets) != 1 or not isinstance(node.targets[0], ast.Name):
        line = getattr(node, "lineno", None)
        col = getattr(node, "col_offset", None)
        raise UnsupportedPatternError(
            "Dict comprehension assignments must target a single variable",
            "Assign the comprehension to a simple variable like `result = {k: v for k, v in pairs}`",
            line=line,
            col=col,
        )

    dictcomp = node.value
    if len(dictcomp.generators) != 1:
        line = getattr(dictcomp, "lineno", None)
        col = getattr(dictcomp, "col_offset", None)
        raise UnsupportedPatternError(
            "Dict comprehensions with multiple generators are not supported",
            "Use nested for loops instead of combining multiple generators in one comprehension",
            line=line,
            col=col,
        )

    gen = dictcomp.generators[0]
    if gen.is_async:
        line = getattr(dictcomp, "lineno", None)
        col = getattr(dictcomp, "col_offset", None)
        raise UnsupportedPatternError(
            "Async dict comprehensions are not supported",
            "Rewrite using an explicit async for loop",
            line=line,
            col=col,
        )

    target_name = node.targets[0].id

    # If the comprehension reads (or loops over) the target name, build the
    # dict in a temporary so those reads still see the pre-assignment value.
    mentions_target = any(
        isinstance(child, ast.Name) and child.id == target_name
        for child in ast.walk(dictcomp)
    )
    acc_name = (
        self._ctx.next_implicit_fn_name(prefix="dictcomp_acc")
        if mentions_target
        else target_name
    )

    # Initialize accumulator: acc = {}
    init_assign_ast = ast.Assign(
        targets=[ast.Name(id=acc_name, ctx=ast.Store())],
        value=ast.Dict(keys=[], values=[]),
        type_comment=None,
    )
    ast.copy_location(init_assign_ast, node)
    ast.fix_missing_locations(init_assign_ast)

    # acc[key] = value
    subscript_target = ast.Subscript(
        value=ast.Name(id=acc_name, ctx=ast.Load()),
        slice=copy.deepcopy(dictcomp.key),
        ctx=ast.Store(),
    )
    append_assign_ast = ast.Assign(
        targets=[subscript_target],
        value=copy.deepcopy(dictcomp.value),
        type_comment=None,
    )
    ast.copy_location(append_assign_ast, node.value)
    ast.fix_missing_locations(append_assign_ast)

    loop_body: List[ast.stmt] = []
    if gen.ifs:
        # Multiple `if` clauses are combined with `and`, matching Python.
        condition: ast.expr
        if len(gen.ifs) == 1:
            condition = copy.deepcopy(gen.ifs[0])
        else:
            condition = ast.BoolOp(op=ast.And(), values=[copy.deepcopy(iff) for iff in gen.ifs])
        ast.copy_location(condition, gen.ifs[0])
        if_stmt = ast.If(test=condition, body=[append_assign_ast], orelse=[])
        ast.copy_location(if_stmt, gen.ifs[0])
        ast.fix_missing_locations(if_stmt)
        loop_body.append(if_stmt)
    else:
        loop_body.append(append_assign_ast)

    loop_ast = ast.For(
        target=copy.deepcopy(gen.target),
        iter=copy.deepcopy(gen.iter),
        body=loop_body,
        orelse=[],
        type_comment=None,
    )
    ast.copy_location(loop_ast, node)
    ast.fix_missing_locations(loop_ast)

    statements: List[ir.Statement] = []
    init_stmt = self._visit_assign(init_assign_ast)
    if init_stmt:
        statements.append(init_stmt)
    statements.extend(self._visit_for(loop_ast))

    if acc_name != target_name:
        # target = acc, only once the comprehension is fully evaluated.
        final_assign_ast = ast.Assign(
            targets=[ast.Name(id=target_name, ctx=ast.Store())],
            value=ast.Name(id=acc_name, ctx=ast.Load()),
            type_comment=None,
        )
        ast.copy_location(final_assign_ast, node)
        ast.fix_missing_locations(final_assign_ast)
        final_stmt = self._visit_assign(final_assign_ast)
        if final_stmt:
            statements.append(final_stmt)

    return statements
1260
+
1261
+ def _visit_ann_assign(self, node: ast.AnnAssign) -> Optional[ir.Statement]:
1262
+ """Convert annotated assignment to IR when a value is present."""
1263
+ if node.value is None:
1264
+ return None
1265
+
1266
+ assign = ast.Assign(targets=[node.target], value=node.value, type_comment=None)
1267
+ ast.copy_location(assign, node)
1268
+ ast.fix_missing_locations(assign)
1269
+ return self._visit_assign(assign)
1270
+
1271
def _visit_assign(self, node: ast.Assign) -> Optional[ir.Statement]:
    """Lower an assignment to an IR ``Assignment`` statement.

    Every form shares the same ``Assignment`` shape, so unpacking works
    uniformly for:
      - action calls:      a, b = @get_pair()
      - parallel blocks:   a, b = parallel: @x() @y()
      - plain expressions: a, b = some_list

    The right-hand side is tried, in order, as: a model constructor
    (lowered to a dict expression), ``await asyncio.gather(...)`` (lowered
    to a parallel or spread expression), an action call, and finally a
    regular expression. Returns None if the regular expression cannot be
    converted.

    Raises:
        UnsupportedPatternError: for constructor calls that are not models
            (x = MyClass(...)) and for awaits of non-actions
            (x = await some_func()).
    """
    stmt = ir.Statement(span=_make_span(node))
    targets = self._get_assign_targets(node.targets)
    rhs = node.value

    def finish(value: ir.Expr) -> ir.Statement:
        # Attach `targets = value` to the statement created above.
        stmt.assignment.CopyFrom(ir.Assignment(targets=targets, value=value))
        return stmt

    # Pydantic models / dataclasses are allowed and become dict literals.
    model_name = self._is_model_constructor(rhs)
    if model_name and isinstance(rhs, ast.Call):
        return finish(self._convert_model_constructor_to_dict(rhs, model_name))

    # Any other constructor call is rejected; must run after the model check.
    self._check_constructor_in_assignment(rhs)

    awaited = rhs.value if isinstance(rhs, ast.Await) else None
    if isinstance(awaited, ast.Call) and self._is_asyncio_gather_call(awaited):
        gathered = self._convert_asyncio_gather(awaited)
        if gathered is not None:
            if isinstance(gathered, ir.ParallelExpr):
                return finish(ir.Expr(parallel_expr=gathered, span=_make_span(node)))
            return finish(ir.Expr(spread_expr=gathered, span=_make_span(node)))

    action_call = self._extract_action_call(rhs)
    if action_call:
        return finish(ir.Expr(action_call=action_call, span=_make_span(node)))

    # Regular assignment (variables, literals, expressions).
    value_expr = self._expr_to_ir_with_model_coercion(rhs)
    return finish(value_expr) if value_expr else None
1331
+
1332
def _visit_expr_stmt(self, node: ast.Expr) -> Optional[ir.Statement]:
    """Lower an expression statement (evaluated for side effects only).

    Handled forms, tried in order:
      1. ``await asyncio.gather(...)``: a ParallelBlock statement, or a
         target-less Assignment wrapping a spread expression.
      2. A bare action call: an action_call statement.
      3. ``name.append(x)`` with one positional argument: rewritten to
         ``name = name + [x]`` so the mutation is an explicit rebinding
         and data flows correctly through the DAG.
      4. Any other expression: an ExprStmt.

    Returns None when no form applies and the expression cannot be
    converted.
    """
    stmt = ir.Statement(span=_make_span(node))
    expr_node = node.value

    # 1. Side-effect-only asyncio.gather(...).
    awaited = expr_node.value if isinstance(expr_node, ast.Await) else None
    if isinstance(awaited, ast.Call) and self._is_asyncio_gather_call(awaited):
        gathered = self._convert_asyncio_gather(awaited)
        if isinstance(gathered, ir.ParallelExpr):
            block = ir.ParallelBlock()
            block.calls.extend(gathered.calls)
            stmt.parallel_block.CopyFrom(block)
            return stmt
        if gathered is not None:
            # e.g. await asyncio.gather(*[action(x) for x in items])
            spread_value = ir.Expr(spread_expr=gathered, span=_make_span(node))
            stmt.assignment.CopyFrom(ir.Assignment(targets=[], value=spread_value))
            return stmt

    # 2. Bare action call.
    action_call = self._extract_action_call(expr_node)
    if action_call:
        stmt.action_call.CopyFrom(action_call)
        return stmt

    # 3. name.append(x) -> name = name + [x]
    func = expr_node.func if isinstance(expr_node, ast.Call) else None
    if (
        isinstance(func, ast.Attribute)
        and func.attr == "append"
        and isinstance(func.value, ast.Name)
        and len(expr_node.args) == 1
    ):
        list_name = func.value.id
        current_list = ir.Expr(variable=ir.Variable(name=list_name), span=_make_span(node))
        appended = self._expr_to_ir_with_model_coercion(expr_node.args[0])
        if appended:
            singleton = ir.Expr(list=ir.ListExpr(elements=[appended]), span=_make_span(node))
            concatenated = ir.Expr(
                binary_op=ir.BinaryOp(
                    op=ir.BinaryOperator.BINARY_OP_ADD,
                    left=current_list,
                    right=singleton,
                ),
                span=_make_span(node),
            )
            stmt.assignment.CopyFrom(ir.Assignment(targets=[list_name], value=concatenated))
            return stmt

    # 4. Any other expression.
    plain = self._expr_to_ir_with_model_coercion(expr_node)
    if not plain:
        return None
    stmt.expr_stmt.CopyFrom(ir.ExprStmt(expr=plain))
    return stmt
1400
+
1401
def _visit_for(self, node: ast.For) -> List[ir.Statement]:
    """Convert a for loop to an IR ``ForLoop`` statement.

    The loop body is emitted as a full block so it can contain multiple
    statements/calls and early ``return``.

    ``ForLoop.loop_vars`` is a flat list of names. Supported targets are
    therefore a single name (``for x in xs``) or a flat tuple/list of names
    (``for k, v in pairs``).

    Returns:
        A one-element list holding the loop statement, or ``[]`` if the
        iterable could not be converted to IR.

    Raises:
        UnsupportedPatternError: for ``for ... else`` (the else clause would
            otherwise be silently dropped), and for loop targets containing
            anything other than plain names, such as nested tuples,
            starred names, attributes, or subscripts. Those variables would
            otherwise be silently missing from ``loop_vars``.
    """
    line = getattr(node, "lineno", None)
    col = getattr(node, "col_offset", None)

    if node.orelse:
        raise UnsupportedPatternError(
            "for/else loops are not supported",
            "Track whether the loop exited early with a flag variable and "
            "use a regular if statement after the loop",
            line=line,
            col=col,
        )

    # Get loop variables
    loop_vars: List[str] = []
    target = node.target
    if isinstance(target, ast.Name):
        loop_vars.append(target.id)
    elif isinstance(target, (ast.Tuple, ast.List)) and all(
        isinstance(elt, ast.Name) for elt in target.elts
    ):
        loop_vars.extend(elt.id for elt in target.elts)
    else:
        raise UnsupportedPatternError(
            "Only a name or a flat tuple of names is supported as a for loop target",
            "Unpack nested values inside the loop body, e.g. "
            "`for i, pair in enumerate(pairs):` followed by `a, b = pair`",
            line=line,
            col=col,
        )

    # Get iterable
    iterable = self._expr_to_ir_with_model_coercion(node.iter)
    if not iterable:
        return []

    # Build body statements (recursively transforms nested structures)
    body_stmts: List[ir.Statement] = []
    for body_node in node.body:
        body_stmts.extend(self._visit_statement(body_node))

    stmt = ir.Statement(span=_make_span(node))
    for_loop = ir.ForLoop(
        loop_vars=loop_vars,
        iterable=iterable,
        block_body=ir.Block(statements=body_stmts, span=_make_span(node)),
    )
    stmt.for_loop.CopyFrom(for_loop)
    return [stmt]
1435
+
1436
def _visit_while(self, node: ast.While) -> List[ir.Statement]:
    """Convert a while loop to an IR ``WhileLoop`` statement.

    The loop body is emitted as a full block so it can contain multiple
    statements/calls and early ``return``.

    Returns:
        A one-element list holding the loop statement, or ``[]`` if the
        condition could not be converted to IR.

    Raises:
        UnsupportedPatternError: for ``while ... else``. The IR loop has no
            else clause, so it would otherwise be silently dropped.
    """
    if node.orelse:
        raise UnsupportedPatternError(
            "while/else loops are not supported",
            "Track whether the loop exited early with a flag variable and "
            "use a regular if statement after the loop",
            line=getattr(node, "lineno", None),
            col=getattr(node, "col_offset", None),
        )

    condition = self._expr_to_ir_with_model_coercion(node.test)
    if not condition:
        return []

    body_stmts: List[ir.Statement] = []
    for body_node in node.body:
        body_stmts.extend(self._visit_statement(body_node))

    stmt = ir.Statement(span=_make_span(node))
    while_loop = ir.WhileLoop(
        condition=condition,
        block_body=ir.Block(statements=body_stmts, span=_make_span(node)),
    )
    stmt.while_loop.CopyFrom(while_loop)
    return [stmt]
1458
+
1459
+ def _detect_accumulator_targets(
1460
+ self, stmts: List[ir.Statement], in_scope_vars: set
1461
+ ) -> List[str]:
1462
+ """Detect out-of-scope variable modifications in for loop body.
1463
+
1464
+ Scans statements for patterns that modify variables defined outside the loop.
1465
+ Returns a list of accumulator variable names that should be set as targets.
1466
+
1467
+ Supported patterns:
1468
+ 1. List append: results.append(value) -> "results"
1469
+ 2. Dict subscript: result[key] = value -> "result"
1470
+ 3. List/set update methods: results.extend(...), results.add(...) -> "results"
1471
+
1472
+ Note: Patterns like `results = results + [x]` and `count = count + 1` create
1473
+ new assignments which are tracked via in_scope_vars and don't need special
1474
+ detection here - they're handled by the regular assignment target logic.
1475
+ """
1476
+ accumulators: List[str] = []
1477
+ seen: set = set()
1478
+
1479
+ for stmt in stmts:
1480
+ var_name = self._extract_accumulator_from_stmt(stmt, in_scope_vars)
1481
+ if var_name and var_name not in seen:
1482
+ accumulators.append(var_name)
1483
+ seen.add(var_name)
1484
+
1485
+ # Check conditionals for accumulator targets in branch bodies
1486
+ if stmt.HasField("conditional"):
1487
+ cond = stmt.conditional
1488
+ branch_blocks: list[ir.Block] = []
1489
+ if cond.HasField("if_branch") and cond.if_branch.HasField("block_body"):
1490
+ branch_blocks.append(cond.if_branch.block_body)
1491
+ for branch in cond.elif_branches:
1492
+ if branch.HasField("block_body"):
1493
+ branch_blocks.append(branch.block_body)
1494
+ if cond.HasField("else_branch") and cond.else_branch.HasField("block_body"):
1495
+ branch_blocks.append(cond.else_branch.block_body)
1496
+
1497
+ for block in branch_blocks:
1498
+ for var in self._detect_accumulator_targets(
1499
+ list(block.statements), in_scope_vars
1500
+ ):
1501
+ if var not in seen:
1502
+ accumulators.append(var)
1503
+ seen.add(var)
1504
+
1505
+ return accumulators
1506
+
1507
+ def _extract_accumulator_from_stmt(
1508
+ self, stmt: ir.Statement, in_scope_vars: set
1509
+ ) -> Optional[str]:
1510
+ """Extract accumulator variable name from a single statement.
1511
+
1512
+ Returns the variable name if this statement modifies an out-of-scope variable,
1513
+ None otherwise.
1514
+ """
1515
+ # Pattern 1: Method calls like list.append(), dict.update(), set.add()
1516
+ if stmt.HasField("expr_stmt"):
1517
+ expr = stmt.expr_stmt.expr
1518
+ if expr.HasField("function_call"):
1519
+ fn_name = expr.function_call.name
1520
+ # Check for mutating method calls: x.append, x.extend, x.add, x.update, etc.
1521
+ mutating_methods = {
1522
+ ".append",
1523
+ ".extend",
1524
+ ".add",
1525
+ ".update",
1526
+ ".insert",
1527
+ ".pop",
1528
+ ".remove",
1529
+ ".clear",
1530
+ }
1531
+ for method in mutating_methods:
1532
+ if fn_name.endswith(method):
1533
+ var_name = fn_name[: len(fn_name) - len(method)]
1534
+ # Only return if it's an out-of-scope variable
1535
+ if var_name and var_name not in in_scope_vars:
1536
+ return var_name
1537
+
1538
+ # Pattern 2: Subscript assignment like dict[key] = value
1539
+ if stmt.HasField("assignment"):
1540
+ for target in stmt.assignment.targets:
1541
+ # Check if target is a subscript pattern (contains '[')
1542
+ if "[" in target:
1543
+ # Extract base variable name (before '[')
1544
+ var_name = target.split("[")[0]
1545
+ if var_name and var_name not in in_scope_vars:
1546
+ return var_name
1547
+
1548
+ # Pattern 3: Self-referential assignment like x = x + [y]
1549
+ # The target variable is used on the RHS, so it must come from outside.
1550
+ # Note: We don't check in_scope_vars here because the assignment itself
1551
+ # would have added the target to in_scope_vars, but it still needs its
1552
+ # previous value from outside the loop body.
1553
+ if stmt.HasField("assignment"):
1554
+ assign = stmt.assignment
1555
+ rhs_vars = self._collect_variables_from_expr(assign.value)
1556
+ for target in assign.targets:
1557
+ if target in rhs_vars:
1558
+ return target
1559
+
1560
+ return None
1561
+
1562
def _collect_variables_from_expr(self, expr: ir.Expr) -> set:
    """Return the set of variable names referenced in *expr*, recursively.

    Descends into variables, binary/unary operators, list elements, dict
    keys and values, index expressions, dot access, and the keyword-argument
    values of function and action calls. Any other expression kind
    contributes no names.
    """
    # NOTE(review): only kwargs of calls are inspected; if the IR call
    # messages also carry positional args, those are not scanned — confirm.
    walk = self._collect_variables_from_expr
    names: set = set()

    if expr.HasField("variable"):
        names.add(expr.variable.name)
    elif expr.HasField("binary_op"):
        names |= walk(expr.binary_op.left)
        names |= walk(expr.binary_op.right)
    elif expr.HasField("unary_op"):
        names |= walk(expr.unary_op.operand)
    elif expr.HasField("list"):
        for element in expr.list.elements:
            names |= walk(element)
    elif expr.HasField("dict"):
        for entry in expr.dict.entries:
            names |= walk(entry.key)
            names |= walk(entry.value)
    elif expr.HasField("index"):
        names |= walk(expr.index.value)
        names |= walk(expr.index.index)
    elif expr.HasField("dot"):
        names |= walk(expr.dot.object)
    elif expr.HasField("function_call"):
        for keyword in expr.function_call.kwargs:
            names |= walk(keyword.value)
    elif expr.HasField("action_call"):
        for keyword in expr.action_call.kwargs:
            names |= walk(keyword.value)

    return names
1593
+
1594
def _visit_if(self, node: ast.If) -> List[ir.Statement]:
    """Convert an if/elif/else chain to IR.

    A condition that is an awaited action call is hoisted into a temporary:
        if await some_action(...):
            ...
    becomes:
        __if_cond_n__ = await some_action(...)
        if __if_cond_n__:
            ...

    If only the leading ``if`` needs a hoisted statement, the result is that
    statement followed by one flat conditional with elif/else branches. If
    any ``elif`` needs one, the chain is rebuilt as nested if/else, so each
    hoisted action runs only after every earlier condition was false. This
    preserves Python's evaluation order.

    Branches whose condition converts to None are skipped. In the flat case,
    if the leading condition converts to None, nothing is emitted.

    Raises:
        UnsupportedPatternError: if a condition contains an action call but
            is not itself an ``await`` expression.
    """

    def lower_condition(
        test: ast.expr,
    ) -> tuple[List[ir.Statement], Optional[ir.Expr]]:
        # Returns (statements to run before the test, IR condition).
        action_call = self._extract_action_call(test)
        if action_call is None:
            return [], self._expr_to_ir_with_model_coercion(test)

        if not isinstance(test, ast.Await):
            raise UnsupportedPatternError(
                "Action calls inside boolean expressions are not supported in if conditions",
                "Assign the awaited action result to a variable, then use the variable in the if condition.",
                line=getattr(test, "lineno", None),
                col=getattr(test, "col_offset", None),
            )

        temp_name = self._ctx.next_implicit_fn_name(prefix="if_cond")
        hoisted = ir.Statement(span=_make_span(test))
        hoisted.assignment.CopyFrom(
            ir.Assignment(
                targets=[temp_name],
                value=ir.Expr(action_call=action_call, span=_make_span(test)),
            )
        )
        temp_ref = ir.Expr(variable=ir.Variable(name=temp_name), span=_make_span(test))
        return [hoisted], temp_ref

    def lower_block(py_stmts: list[ast.stmt]) -> List[ir.Statement]:
        return [
            ir_stmt
            for py_stmt in py_stmts
            for ir_stmt in self._visit_statement(py_stmt)
        ]

    # Follow the elif chain: an `elif` appears as an orelse holding exactly
    # one If node.
    chain: list[ast.If] = [node]
    while len(chain[-1].orelse) == 1 and isinstance(chain[-1].orelse[0], ast.If):
        chain.append(chain[-1].orelse[0])
    trailing_else = chain[-1].orelse

    # Lower in source order: test, body, test, body, ..., then the else body.
    lowered: list[
        tuple[List[ir.Statement], Optional[ir.Expr], List[ir.Statement], ast.If]
    ] = []
    for if_node in chain:
        prefix, condition = lower_condition(if_node.test)
        lowered.append((prefix, condition, lower_block(if_node.body), if_node))
    else_body = lower_block(trailing_else) if trailing_else else []

    def make_conditional(
        condition: ir.Expr,
        then_body: List[ir.Statement],
        otherwise: List[ir.Statement],
        span_node: ast.AST,
    ) -> ir.Statement:
        # Build `if condition: then_body [else: otherwise]` as one statement.
        span = _make_span(span_node)
        conditional = ir.Conditional(
            if_branch=ir.IfBranch(
                condition=condition,
                block_body=ir.Block(statements=then_body, span=span),
                span=span,
            )
        )
        if otherwise:
            conditional.else_branch.CopyFrom(
                ir.ElseBranch(
                    block_body=ir.Block(statements=otherwise, span=span),
                    span=span,
                )
            )
        result = ir.Statement(span=span)
        result.conditional.CopyFrom(conditional)
        return result

    if any(prefix for prefix, _, _, _ in lowered[1:]):
        # Nest from the innermost branch outwards; each hoisted prefix sits
        # inside the else of the previous condition.
        tail: List[ir.Statement] = else_body
        for prefix, condition, then_body, span_node in reversed(lowered):
            if condition is None:
                continue
            tail = [*prefix, make_conditional(condition, then_body, tail, span_node)]
        return tail

    head_prefix, head_condition, head_body, head_node = lowered[0]
    if head_condition is None:
        return []

    head_span = _make_span(head_node)
    conditional = ir.Conditional(
        if_branch=ir.IfBranch(
            condition=head_condition,
            block_body=ir.Block(statements=head_body, span=head_span),
            span=head_span,
        )
    )

    for _, elif_condition, elif_body, elif_node in lowered[1:]:
        if elif_condition is None:
            continue
        elif_span = _make_span(elif_node)
        conditional.elif_branches.append(
            ir.ElifBranch(
                condition=elif_condition,
                block_body=ir.Block(statements=elif_body, span=elif_span),
                span=elif_span,
            )
        )

    if else_body:
        conditional.else_branch.CopyFrom(
            ir.ElseBranch(
                block_body=ir.Block(statements=else_body, span=head_span),
                span=head_span,
            )
        )

    flat = ir.Statement(span=head_span)
    flat.conditional.CopyFrom(conditional)
    return [*head_prefix, flat]
1731
+
1732
+ def _collect_assigned_vars(self, stmts: List[ir.Statement]) -> set:
1733
+ """Collect all variable names assigned in a list of statements."""
1734
+ assigned = set()
1735
+ for stmt in stmts:
1736
+ if stmt.HasField("assignment"):
1737
+ assigned.update(stmt.assignment.targets)
1738
+ return assigned
1739
+
1740
+ def _collect_assigned_vars_in_order(self, stmts: List[ir.Statement]) -> list[str]:
1741
+ """Collect assigned variable names in statement order (deduplicated)."""
1742
+ assigned: list[str] = []
1743
+ seen: set[str] = set()
1744
+
1745
+ for stmt in stmts:
1746
+ if stmt.HasField("assignment"):
1747
+ for target in stmt.assignment.targets:
1748
+ if target not in seen:
1749
+ seen.add(target)
1750
+ assigned.append(target)
1751
+
1752
+ if stmt.HasField("conditional"):
1753
+ cond = stmt.conditional
1754
+ if cond.HasField("if_branch") and cond.if_branch.HasField("block_body"):
1755
+ for target in self._collect_assigned_vars_in_order(
1756
+ list(cond.if_branch.block_body.statements)
1757
+ ):
1758
+ if target not in seen:
1759
+ seen.add(target)
1760
+ assigned.append(target)
1761
+ for elif_branch in cond.elif_branches:
1762
+ if elif_branch.HasField("block_body"):
1763
+ for target in self._collect_assigned_vars_in_order(
1764
+ list(elif_branch.block_body.statements)
1765
+ ):
1766
+ if target not in seen:
1767
+ seen.add(target)
1768
+ assigned.append(target)
1769
+ if cond.HasField("else_branch") and cond.else_branch.HasField("block_body"):
1770
+ for target in self._collect_assigned_vars_in_order(
1771
+ list(cond.else_branch.block_body.statements)
1772
+ ):
1773
+ if target not in seen:
1774
+ seen.add(target)
1775
+ assigned.append(target)
1776
+
1777
+ if stmt.HasField("for_loop") and stmt.for_loop.HasField("block_body"):
1778
+ for target in self._collect_assigned_vars_in_order(
1779
+ list(stmt.for_loop.block_body.statements)
1780
+ ):
1781
+ if target not in seen:
1782
+ seen.add(target)
1783
+ assigned.append(target)
1784
+
1785
+ if stmt.HasField("while_loop") and stmt.while_loop.HasField("block_body"):
1786
+ for target in self._collect_assigned_vars_in_order(
1787
+ list(stmt.while_loop.block_body.statements)
1788
+ ):
1789
+ if target not in seen:
1790
+ seen.add(target)
1791
+ assigned.append(target)
1792
+
1793
+ if stmt.HasField("try_except"):
1794
+ try_block = stmt.try_except.try_block
1795
+ if try_block.HasField("span"):
1796
+ for target in self._collect_assigned_vars_in_order(list(try_block.statements)):
1797
+ if target not in seen:
1798
+ seen.add(target)
1799
+ assigned.append(target)
1800
+ for handler in stmt.try_except.handlers:
1801
+ if handler.HasField("block_body"):
1802
+ for target in self._collect_assigned_vars_in_order(
1803
+ list(handler.block_body.statements)
1804
+ ):
1805
+ if target not in seen:
1806
+ seen.add(target)
1807
+ assigned.append(target)
1808
+
1809
+ return assigned
1810
+
1811
+ def _collect_variables_from_block(self, block: ir.Block) -> list[str]:
1812
+ return self._collect_variables_from_statements(list(block.statements))
1813
+
1814
+ def _collect_variables_from_statements(self, stmts: List[ir.Statement]) -> list[str]:
1815
+ """Collect variable references from statements in encounter order."""
1816
+ vars_found: list[str] = []
1817
+ seen: set[str] = set()
1818
+
1819
+ for stmt in stmts:
1820
+ if stmt.HasField("assignment") and stmt.assignment.HasField("value"):
1821
+ for var in self._collect_variables_from_expr(stmt.assignment.value):
1822
+ if var not in seen:
1823
+ seen.add(var)
1824
+ vars_found.append(var)
1825
+
1826
+ if stmt.HasField("return_stmt") and stmt.return_stmt.HasField("value"):
1827
+ for var in self._collect_variables_from_expr(stmt.return_stmt.value):
1828
+ if var not in seen:
1829
+ seen.add(var)
1830
+ vars_found.append(var)
1831
+
1832
+ if stmt.HasField("action_call"):
1833
+ expr = ir.Expr(action_call=stmt.action_call, span=stmt.span)
1834
+ for var in self._collect_variables_from_expr(expr):
1835
+ if var not in seen:
1836
+ seen.add(var)
1837
+ vars_found.append(var)
1838
+
1839
+ if stmt.HasField("expr_stmt"):
1840
+ for var in self._collect_variables_from_expr(stmt.expr_stmt.expr):
1841
+ if var not in seen:
1842
+ seen.add(var)
1843
+ vars_found.append(var)
1844
+
1845
+ if stmt.HasField("conditional"):
1846
+ cond = stmt.conditional
1847
+ if cond.HasField("if_branch"):
1848
+ if cond.if_branch.HasField("condition"):
1849
+ for var in self._collect_variables_from_expr(cond.if_branch.condition):
1850
+ if var not in seen:
1851
+ seen.add(var)
1852
+ vars_found.append(var)
1853
+ if cond.if_branch.HasField("block_body"):
1854
+ for var in self._collect_variables_from_block(cond.if_branch.block_body):
1855
+ if var not in seen:
1856
+ seen.add(var)
1857
+ vars_found.append(var)
1858
+ for elif_branch in cond.elif_branches:
1859
+ if elif_branch.HasField("condition"):
1860
+ for var in self._collect_variables_from_expr(elif_branch.condition):
1861
+ if var not in seen:
1862
+ seen.add(var)
1863
+ vars_found.append(var)
1864
+ if elif_branch.HasField("block_body"):
1865
+ for var in self._collect_variables_from_block(elif_branch.block_body):
1866
+ if var not in seen:
1867
+ seen.add(var)
1868
+ vars_found.append(var)
1869
+ if cond.HasField("else_branch") and cond.else_branch.HasField("block_body"):
1870
+ for var in self._collect_variables_from_block(cond.else_branch.block_body):
1871
+ if var not in seen:
1872
+ seen.add(var)
1873
+ vars_found.append(var)
1874
+
1875
+ if stmt.HasField("for_loop"):
1876
+ fl = stmt.for_loop
1877
+ if fl.HasField("iterable"):
1878
+ for var in self._collect_variables_from_expr(fl.iterable):
1879
+ if var not in seen:
1880
+ seen.add(var)
1881
+ vars_found.append(var)
1882
+ if fl.HasField("block_body"):
1883
+ for var in self._collect_variables_from_block(fl.block_body):
1884
+ if var not in seen:
1885
+ seen.add(var)
1886
+ vars_found.append(var)
1887
+
1888
+ if stmt.HasField("while_loop"):
1889
+ wl = stmt.while_loop
1890
+ if wl.HasField("condition"):
1891
+ for var in self._collect_variables_from_expr(wl.condition):
1892
+ if var not in seen:
1893
+ seen.add(var)
1894
+ vars_found.append(var)
1895
+ if wl.HasField("block_body"):
1896
+ for var in self._collect_variables_from_block(wl.block_body):
1897
+ if var not in seen:
1898
+ seen.add(var)
1899
+ vars_found.append(var)
1900
+
1901
+ if stmt.HasField("try_except"):
1902
+ te = stmt.try_except
1903
+ if te.HasField("try_block"):
1904
+ for var in self._collect_variables_from_block(te.try_block):
1905
+ if var not in seen:
1906
+ seen.add(var)
1907
+ vars_found.append(var)
1908
+ for handler in te.handlers:
1909
+ if handler.HasField("block_body"):
1910
+ for var in self._collect_variables_from_block(handler.block_body):
1911
+ if var not in seen:
1912
+ seen.add(var)
1913
+ vars_found.append(var)
1914
+
1915
+ if stmt.HasField("parallel_block"):
1916
+ for call in stmt.parallel_block.calls:
1917
+ if call.HasField("action"):
1918
+ for kwarg in call.action.kwargs:
1919
+ for var in self._collect_variables_from_expr(kwarg.value):
1920
+ if var not in seen:
1921
+ seen.add(var)
1922
+ vars_found.append(var)
1923
+ elif call.HasField("function"):
1924
+ for kwarg in call.function.kwargs:
1925
+ for var in self._collect_variables_from_expr(kwarg.value):
1926
+ if var not in seen:
1927
+ seen.add(var)
1928
+ vars_found.append(var)
1929
+
1930
+ if stmt.HasField("spread_action"):
1931
+ spread = stmt.spread_action
1932
+ if spread.HasField("collection"):
1933
+ for var in self._collect_variables_from_expr(spread.collection):
1934
+ if var not in seen:
1935
+ seen.add(var)
1936
+ vars_found.append(var)
1937
+ if spread.HasField("action"):
1938
+ for kwarg in spread.action.kwargs:
1939
+ for var in self._collect_variables_from_expr(kwarg.value):
1940
+ if var not in seen:
1941
+ seen.add(var)
1942
+ vars_found.append(var)
1943
+
1944
+ return vars_found
1945
+
1946
def _visit_try(self, node: ast.Try) -> List[ir.Statement]:
    """Convert try/except to IR with full block bodies.

    The try body and each handler body are visited recursively, so nested
    control flow is transformed as well.

    Handler exception types are recorded by class name, alone or in a tuple:
    - ``except ValueError`` is recorded as ``"ValueError"``;
    - ``except asyncio.TimeoutError`` is recorded as ``"TimeoutError"``;
    - a handler with no type gets an empty ``exception_types`` list.

    Raises:
        UnsupportedPatternError: if the ``try`` has an ``else`` or
            ``finally`` clause, or if a handler's exception type is neither a
            name nor a dotted name. This converter only emits the try block
            and the handlers, so those constructs would otherwise be silently
            dropped or turned into a catch-all handler.
    """
    line = getattr(node, "lineno", None)
    col = getattr(node, "col_offset", None)

    if node.finalbody or node.orelse:
        clause = "finally" if node.finalbody else "else"
        raise UnsupportedPatternError(
            f"'try' with an '{clause}' clause is not supported",
            f"Move the code from the '{clause}' clause after the try/except block, "
            "or into the try body and each except handler as appropriate.",
            line=line,
            col=col,
        )

    def type_name(type_node: ast.expr) -> str:
        # Record the bare class name for both `Name` and dotted `Attribute` forms.
        if isinstance(type_node, ast.Name):
            return type_node.id
        if isinstance(type_node, ast.Attribute):
            return type_node.attr
        raise UnsupportedPatternError(
            "Unsupported exception type expression in 'except' clause",
            "Name the exception class directly, e.g. 'except ValueError:' or "
            "'except (KeyError, TimeoutError):'.",
            line=getattr(type_node, "lineno", line),
            col=getattr(type_node, "col_offset", col),
        )

    # Build try body statements (recursively transforms nested structures)
    try_body: List[ir.Statement] = []
    for body_node in node.body:
        try_body.extend(self._visit_statement(body_node))

    handlers: List[ir.ExceptHandler] = []
    for handler in node.handlers:
        exception_types: List[str] = []
        if handler.type is not None:
            if isinstance(handler.type, ast.Tuple):
                exception_types = [type_name(elt) for elt in handler.type.elts]
            else:
                exception_types = [type_name(handler.type)]

        # Build handler body (recursively transforms nested structures)
        handler_body: List[ir.Statement] = []
        for handler_node in handler.body:
            handler_body.extend(self._visit_statement(handler_node))

        except_handler = ir.ExceptHandler(
            exception_types=exception_types,
            block_body=ir.Block(statements=handler_body, span=_make_span(handler)),
            span=_make_span(handler),
        )
        if handler.name:
            except_handler.exception_var = handler.name
        handlers.append(except_handler)

    try_stmt = ir.Statement(span=_make_span(node))
    try_stmt.try_except.CopyFrom(
        ir.TryExcept(
            handlers=handlers,
            try_block=ir.Block(statements=try_body, span=_make_span(node)),
        )
    )
    return [try_stmt]
1990
+
1991
+ def _count_calls(self, stmts: List[ir.Statement]) -> int:
1992
+ """Count action calls and function calls in statements.
1993
+
1994
+ Both action calls and function calls (including synthetic functions)
1995
+ count toward the limit of one call per control flow body.
1996
+ """
1997
+ count = 0
1998
+ for stmt in stmts:
1999
+ if stmt.HasField("action_call"):
2000
+ count += 1
2001
+ elif stmt.HasField("assignment"):
2002
+ # Check if assignment value is an action call or function call
2003
+ if stmt.assignment.value.HasField("action_call"):
2004
+ count += 1
2005
+ elif stmt.assignment.value.HasField("function_call"):
2006
+ count += 1
2007
+ elif stmt.HasField("expr_stmt"):
2008
+ # Check if expression is a function call
2009
+ if stmt.expr_stmt.expr.HasField("function_call"):
2010
+ count += 1
2011
+ return count
2012
+
2013
def _wrap_body_as_function(
    self,
    body: List[ir.Statement],
    prefix: str,
    node: ast.AST,
    inputs: Optional[List[str]] = None,
    modified_vars: Optional[List[str]] = None,
) -> List[ir.Statement]:
    """Hoist ``body`` into a new synthetic function and return a call to it.

    The generated ``ir.FunctionDef`` is appended to
    ``self._ctx.implicit_functions`` under a fresh name derived from
    ``prefix``.

    Args:
        body: Statements that become the synthetic function's body.
        prefix: Name prefix handed to ``next_implicit_fn_name``.
        node: AST node used for span information.
        inputs: Variables passed in as keyword arguments (e.g. loop variables).
        modified_vars: Out-of-scope variables the body modifies. Each one is:
            - appended to the inputs if not already present;
            - declared as one of the function's outputs;
            - returned at the end of the body (one variable directly, several
              as a list, since tuples are lists in the IR);
            - assigned back at the call site.

    Returns:
        A one-element list. It holds an assignment ``vars = fn(...)`` when
        there are modified variables, otherwise an expression statement
        ``fn(...)``.
    """
    fn_name = self._ctx.next_implicit_fn_name(prefix)
    outputs = modified_vars or []

    # Modified variables flow in as inputs as well as out as outputs.
    fn_inputs = list(inputs or [])
    for name in outputs:
        if name not in fn_inputs:
            fn_inputs.append(name)

    wrapped_body = list(body)
    if outputs:
        if len(outputs) > 1:
            # Tuples are represented as lists in the IR.
            elements = [ir.Expr(variable=ir.Variable(name=name)) for name in outputs]
            return_expr = ir.Expr(list=ir.ListExpr(elements=elements), span=_make_span(node))
        else:
            return_expr = ir.Expr(variable=ir.Variable(name=outputs[0]), span=_make_span(node))
        trailing_return = ir.Statement(span=_make_span(node))
        trailing_return.return_stmt.CopyFrom(ir.ReturnStmt(value=return_expr))
        wrapped_body.append(trailing_return)

    self._ctx.implicit_functions.append(
        ir.FunctionDef(
            name=fn_name,
            io=ir.IoDecl(inputs=fn_inputs, outputs=outputs),
            body=ir.Block(statements=wrapped_body),
            span=_make_span(node),
        )
    )

    # Every input is forwarded by keyword under its own name.
    call_kwargs = [
        ir.Kwarg(name=name, value=ir.Expr(variable=ir.Variable(name=name))) for name in fn_inputs
    ]
    call_expr = ir.Expr(
        function_call=ir.FunctionCall(name=fn_name, kwargs=call_kwargs),
        span=_make_span(node),
    )

    call_stmt = ir.Statement(span=_make_span(node))
    if not outputs:
        call_stmt.expr_stmt.CopyFrom(ir.ExprStmt(expr=call_expr))
        return [call_stmt]

    # Write the returned values back: `a, b = fn(...)` / `a = fn(...)`.
    write_back = ir.Assignment(value=call_expr)
    write_back.targets.extend(outputs)
    call_stmt.assignment.CopyFrom(write_back)
    return [call_stmt]
2095
+
2096
def _visit_return(self, node: ast.Return) -> List[ir.Statement]:
    """Lower a ``return`` statement to IR.

    The returned value is first screened by ``_check_constructor_in_return``,
    which rejects constructor calls such as ``return MyModel(...)``.

    Call results are bound to a temporary before being returned:
    - ``return await action()`` becomes ``_return_tmp = await action()``
      followed by ``return _return_tmp``;
    - a returned function call gets a fresh temporary from
      ``next_implicit_fn_name(prefix="return_tmp")``.

    Any other convertible expression is returned as-is. A missing value
    produces a bare ``return``.
    """

    def bind_and_return(tmp_name: str, value: ir.Expr) -> List[ir.Statement]:
        # `<tmp_name> = <value>` followed by `return <tmp_name>`.
        bind = ir.Statement(span=_make_span(node))
        bind.assignment.CopyFrom(ir.Assignment(targets=[tmp_name], value=value))
        result_ref = ir.Expr(variable=ir.Variable(name=tmp_name), span=_make_span(node))
        ret = ir.Statement(span=_make_span(node))
        ret.return_stmt.CopyFrom(ir.ReturnStmt(value=result_ref))
        return [bind, ret]

    if node.value is not None:
        self._check_constructor_in_return(node.value)

        awaited_action = self._extract_action_call(node.value)
        if awaited_action:
            return bind_and_return(
                "_return_tmp",
                ir.Expr(action_call=awaited_action, span=_make_span(node)),
            )

        converted = self._expr_to_ir_with_model_coercion(node.value)
        if converted:
            if converted.HasField("function_call"):
                fresh_name = self._ctx.next_implicit_fn_name(prefix="return_tmp")
                return bind_and_return(fresh_name, converted)
            plain = ir.Statement(span=_make_span(node))
            plain.return_stmt.CopyFrom(ir.ReturnStmt(value=converted))
            return [plain]
        # NOTE(review): an expression that cannot be converted falls through to
        # a bare `return`, silently dropping the value. Confirm whether
        # _expr_to_ir_with_model_coercion raises for unsupported input.

    bare = ir.Statement(span=_make_span(node))
    bare.return_stmt.CopyFrom(ir.ReturnStmt())
    return [bare]
2157
+
2158
def _visit_break(self, node: ast.Break) -> List[ir.Statement]:
    """Lower a Python ``break`` into a one-element list of IR statements."""
    lowered = ir.Statement(span=_make_span(node))
    # CopyFrom with an empty message marks the `break_stmt` field as set.
    lowered.break_stmt.CopyFrom(ir.BreakStmt())
    return [lowered]
2163
+
2164
def _visit_continue(self, node: ast.Continue) -> List[ir.Statement]:
    """Lower a Python ``continue`` into a one-element list of IR statements."""
    lowered = ir.Statement(span=_make_span(node))
    # CopyFrom with an empty message marks the `continue_stmt` field as set.
    lowered.continue_stmt.CopyFrom(ir.ContinueStmt())
    return [lowered]
2169
+
2170
def _visit_aug_assign(self, node: ast.AugAssign) -> List[ir.Statement]:
    """Convert augmented assignment (+=, -=, etc.) to IR.

    ``target op= value`` is lowered to ``target = target op value``. When the
    right-hand side is a function call, it is first bound to a fresh
    temporary, so the binary operation only references plain values::

        total += helper(x)
        # becomes
        <aug_tmp> = helper(x)
        total = total + <aug_tmp>

    Raises:
        UnsupportedPatternError: in any of these cases, instead of silently
            dropping the statement or emitting an assignment with no target:
            - the target is not a plain name (e.g. ``obj.attr += 1`` or
              ``items[i] += 1``);
            - ``_bin_op_to_ir`` has no IR operator for the operation;
            - either operand cannot be converted to IR.
    """
    line = getattr(node, "lineno", None)
    col = getattr(node, "col_offset", None)

    if not isinstance(node.target, ast.Name):
        raise UnsupportedPatternError(
            "Augmented assignment to an attribute or subscript is not supported",
            "Assign to a plain local variable instead (e.g. 'count += 1'), and build "
            "updated objects or collections inside an @action.",
            line=line,
            col=col,
        )
    targets = [node.target.id]

    left = self._expr_to_ir_with_model_coercion(node.target)
    right = self._expr_to_ir_with_model_coercion(node.value)

    op = _bin_op_to_ir(node.op)
    if not op:
        raise UnsupportedPatternError(
            f"Augmented assignment operator '{type(node.op).__name__}' is not supported",
            "Compute the new value inside an @action and assign its result.",
            line=line,
            col=col,
        )
    if not left or not right:
        raise UnsupportedPatternError(
            "Cannot convert the operands of an augmented assignment",
            "Use simpler expressions (literals, variables, dicts, lists).",
            line=line,
            col=col,
        )

    prelude: List[ir.Statement] = []
    right_operand = right
    if right.HasField("function_call"):
        # Bind the call result to a temporary so the binary op stays call-free.
        tmp_var = self._ctx.next_implicit_fn_name(prefix="aug_tmp")
        assign_tmp = ir.Statement(span=_make_span(node))
        assign_tmp.assignment.CopyFrom(
            ir.Assignment(
                targets=[tmp_var],
                value=ir.Expr(function_call=right.function_call, span=_make_span(node)),
            )
        )
        prelude.append(assign_tmp)
        right_operand = ir.Expr(variable=ir.Variable(name=tmp_var))

    stmt = ir.Statement(span=_make_span(node))
    binary = ir.BinaryOp(left=left, op=op, right=right_operand)
    stmt.assignment.CopyFrom(ir.Assignment(targets=targets, value=ir.Expr(binary_op=binary)))
    return [*prelude, stmt]
2217
+
2218
+ def _collect_function_inputs(self, node: ast.FunctionDef | ast.AsyncFunctionDef) -> List[str]:
2219
+ """Collect workflow inputs from function parameters, including kw-only args."""
2220
+ args: List[str] = []
2221
+ seen: set[str] = set()
2222
+
2223
+ ordered_args = list(node.args.posonlyargs) + list(node.args.args)
2224
+ if ordered_args and ordered_args[0].arg == "self":
2225
+ ordered_args = ordered_args[1:]
2226
+
2227
+ for arg in ordered_args:
2228
+ if arg.arg not in seen:
2229
+ args.append(arg.arg)
2230
+ seen.add(arg.arg)
2231
+
2232
+ if node.args.vararg and node.args.vararg.arg not in seen:
2233
+ args.append(node.args.vararg.arg)
2234
+ seen.add(node.args.vararg.arg)
2235
+
2236
+ for arg in node.args.kwonlyargs:
2237
+ if arg.arg not in seen:
2238
+ args.append(arg.arg)
2239
+ seen.add(arg.arg)
2240
+
2241
+ if node.args.kwarg and node.args.kwarg.arg not in seen:
2242
+ args.append(node.args.kwarg.arg)
2243
+ seen.add(node.args.kwarg.arg)
2244
+
2245
+ return args
2246
+
2247
+ def _check_constructor_in_return(self, node: ast.expr) -> None:
2248
+ """Check for constructor calls in return statements.
2249
+
2250
+ Raises UnsupportedPatternError if the return value is a class instantiation
2251
+ like: return MyModel(field=value)
2252
+
2253
+ This is not supported because the workflow IR cannot serialize arbitrary
2254
+ object instantiation. Users should use an @action to create objects.
2255
+ """
2256
+ # Skip if it's an await (action call) - those are fine
2257
+ if isinstance(node, ast.Await):
2258
+ return
2259
+
2260
+ # Check for direct Call that looks like a constructor
2261
+ if isinstance(node, ast.Call):
2262
+ func_name = self._get_constructor_name(node.func)
2263
+ if func_name and self._looks_like_constructor(func_name, node):
2264
+ line = getattr(node, "lineno", None)
2265
+ col = getattr(node, "col_offset", None)
2266
+ raise UnsupportedPatternError(
2267
+ f"Returning constructor call '{func_name}(...)' is not supported",
2268
+ RECOMMENDATIONS["constructor_return"],
2269
+ line=line,
2270
+ col=col,
2271
+ )
2272
+
2273
+ def _check_constructor_in_assignment(self, node: ast.expr) -> None:
2274
+ """Check for constructor calls in assignments.
2275
+
2276
+ Raises UnsupportedPatternError if the assignment value is a class instantiation
2277
+ like: result = MyModel(field=value)
2278
+
2279
+ This is not supported because the workflow IR cannot serialize arbitrary
2280
+ object instantiation. Users should use an @action to create objects.
2281
+ """
2282
+ # Skip if it's an await (action call) - those are fine
2283
+ if isinstance(node, ast.Await):
2284
+ return
2285
+
2286
+ # Check for direct Call that looks like a constructor
2287
+ if isinstance(node, ast.Call):
2288
+ func_name = self._get_constructor_name(node.func)
2289
+ if func_name and self._looks_like_constructor(func_name, node):
2290
+ line = getattr(node, "lineno", None)
2291
+ col = getattr(node, "col_offset", None)
2292
+ raise UnsupportedPatternError(
2293
+ f"Assigning constructor call '{func_name}(...)' is not supported",
2294
+ RECOMMENDATIONS["constructor_assignment"],
2295
+ line=line,
2296
+ col=col,
2297
+ )
2298
+
2299
+ def _get_constructor_name(self, func: ast.expr) -> Optional[str]:
2300
+ """Get the name from a function expression if it looks like a constructor."""
2301
+ if isinstance(func, ast.Name):
2302
+ return func.id
2303
+ elif isinstance(func, ast.Attribute):
2304
+ return func.attr
2305
+ return None
2306
+
2307
+ def _looks_like_constructor(self, func_name: str, call: ast.Call) -> bool:
2308
+ """Check if a function call looks like a class constructor.
2309
+
2310
+ A constructor is identified by:
2311
+ 1. Name starts with uppercase (PEP8 convention for classes)
2312
+ 2. It's not a known action
2313
+ 3. It's not a known builtin like String operations
2314
+ 4. It's not a known Pydantic model or dataclass (those are allowed)
2315
+
2316
+ This is a heuristic - we can't perfectly distinguish constructors
2317
+ from functions without full type information.
2318
+ """
2319
+ # Check if first letter is uppercase (class naming convention)
2320
+ if not func_name or not func_name[0].isupper():
2321
+ return False
2322
+
2323
+ # If it's a known action, it's not a constructor
2324
+ if func_name in self._action_defs:
2325
+ return False
2326
+
2327
+ # If it's a known Pydantic model or dataclass, allow it
2328
+ # (it will be converted to a dict expression)
2329
+ if func_name in self._model_defs:
2330
+ return False
2331
+
2332
+ # Common builtins that start with uppercase but aren't constructors
2333
+ # (these are rarely used in workflow code but let's be safe)
2334
+ builtin_exceptions = {"True", "False", "None", "Ellipsis"}
2335
+ if func_name in builtin_exceptions:
2336
+ return False
2337
+
2338
+ return True
2339
+
2340
+ def _is_model_constructor(self, node: ast.expr) -> Optional[str]:
2341
+ """Check if an expression is a Pydantic model or dataclass constructor call.
2342
+
2343
+ Returns the model name if it is, None otherwise.
2344
+ """
2345
+ if not isinstance(node, ast.Call):
2346
+ return None
2347
+
2348
+ func_name = self._get_constructor_name(node.func)
2349
+ if func_name and func_name in self._model_defs:
2350
+ return func_name
2351
+
2352
+ return None
2353
+
2354
def _convert_model_constructor_to_dict(self, node: ast.Call, model_name: str) -> ir.Expr:
    """Convert a Pydantic model or dataclass constructor call to a dict expression.

    For example::

        MyModel(field1=value1, field2=value2)

    becomes::

        {"field1": value1, "field2": value2}

    Dict entries are emitted in three groups:
    1. keyword arguments, in call order;
    2. positional arguments, mapped onto the model's fields in declaration
       order (``model_def.fields``);
    3. remaining fields that have a default, when ``_constant_to_literal``
       can express that default as an IR literal. Other defaults are skipped.

    Args:
        node: The constructor call.
        model_name: Key into ``self._model_defs`` for the model being built.

    Returns:
        An ``ir.Expr`` holding a ``DictExpr`` with string keys.

    Raises:
        UnsupportedPatternError: in any of these cases:
            - the call uses ``**kwargs`` expansion;
            - there are more positional arguments than fields;
            - a field is supplied both positionally and by keyword (Python
              itself rejects such a call with ``TypeError``);
            - an argument value cannot be converted.
    """
    model_def = self._model_defs[model_name]
    line = getattr(node, "lineno", None)
    col = getattr(node, "col_offset", None)

    def fail(message: str, recommendation: str) -> NoReturn:
        raise UnsupportedPatternError(message, recommendation, line=line, col=col)

    def string_key(name: str) -> ir.Expr:
        # Dict keys are string literals holding the field name.
        key_expr = ir.Expr()
        key_literal = ir.Literal()
        key_literal.string_value = name
        key_expr.literal.CopyFrom(key_literal)
        return key_expr

    def convert_value(value: ast.expr, description: str) -> ir.Expr:
        value_expr = self._expr_to_ir_with_model_coercion(value)
        if value_expr is None:
            fail(
                f"Cannot convert {description} in '{model_name}'",
                "Use simpler expressions (literals, variables, dicts, lists).",
            )
        return value_expr

    entries: List[ir.DictEntry] = []
    provided_fields: Set[str] = set()

    # Explicit keyword arguments.
    for kw in node.keywords:
        if kw.arg is None:
            fail(
                f"Model constructor '{model_name}' with **kwargs is not supported",
                "Use explicit keyword arguments instead of **kwargs.",
            )
        provided_fields.add(kw.arg)
        entries.append(
            ir.DictEntry(
                key=string_key(kw.arg),
                value=convert_value(kw.value, f"value for field '{kw.arg}'"),
            )
        )

    # Positional arguments (dataclasses support this) map to fields in order.
    field_names = list(model_def.fields.keys())
    for index, arg in enumerate(node.args):
        if index >= len(field_names):
            fail(
                f"Too many positional arguments for '{model_name}'",
                "Use keyword arguments for clarity.",
            )
        field_name = field_names[index]
        if field_name in provided_fields:
            # Would otherwise emit two entries for the same key.
            fail(
                f"Field '{field_name}' of '{model_name}' is given both positionally "
                "and by keyword",
                "Pass each field exactly once.",
            )
        provided_fields.add(field_name)
        entries.append(
            ir.DictEntry(
                key=string_key(field_name),
                value=convert_value(arg, f"positional argument for field '{field_name}'"),
            )
        )

    # Defaults for fields that were not supplied.
    for field_name, field_def in model_def.fields.items():
        if field_name in provided_fields or not field_def.has_default:
            continue
        default_literal = _constant_to_literal(field_def.default_value)
        if default_literal is None:
            # Not expressible as an IR literal; skip this field.
            continue
        value_expr = ir.Expr()
        value_expr.literal.CopyFrom(default_literal)
        entries.append(ir.DictEntry(key=string_key(field_name), value=value_expr))

    result = ir.Expr(span=_make_span(node))
    result.dict.CopyFrom(ir.DictExpr(entries=entries))
    return result
2466
+
2467
+ def _check_non_action_await(self, node: ast.Await) -> None:
2468
+ """Check if an await is for a non-action function.
2469
+
2470
+ Note: We can only reliably detect non-action awaits for functions defined
2471
+ in the same module. Actions imported from other modules will pass through
2472
+ and may fail at runtime if they're not actually actions.
2473
+
2474
+ For now, we only check against common builtins and known non-action patterns.
2475
+ A runtime check will catch functions that aren't registered actions.
2476
+ """
2477
+ awaited = node.value
2478
+ if not isinstance(awaited, ast.Call):
2479
+ return
2480
+
2481
+ # Skip special cases that are handled elsewhere
2482
+ if self._is_run_action_call(awaited):
2483
+ return
2484
+ if self._is_asyncio_sleep_call(awaited):
2485
+ return
2486
+ if self._is_asyncio_gather_call(awaited):
2487
+ return
2488
+
2489
+ # Get the function name
2490
+ func_name = None
2491
+ if isinstance(awaited.func, ast.Name):
2492
+ func_name = awaited.func.id
2493
+ elif isinstance(awaited.func, ast.Attribute):
2494
+ func_name = awaited.func.attr
2495
+
2496
+ if not func_name:
2497
+ return
2498
+
2499
+ # Only raise error for functions defined in THIS module that we know
2500
+ # are NOT actions (i.e., async functions without @action decorator)
2501
+ # We can't reliably detect imported non-actions without full type info.
2502
+ #
2503
+ # The check works by looking at _module_functions which contains
2504
+ # functions defined in the same module as the workflow.
2505
+ if func_name in getattr(self, "_module_functions", set()):
2506
+ if func_name not in self._action_defs:
2507
+ line = getattr(node, "lineno", None)
2508
+ col = getattr(node, "col_offset", None)
2509
+ raise UnsupportedPatternError(
2510
+ f"Awaiting non-action function '{func_name}()' is not supported",
2511
+ RECOMMENDATIONS["non_action_call"],
2512
+ line=line,
2513
+ col=col,
2514
+ )
2515
+
2516
+ def _check_sync_function_call(self, node: ast.Call) -> None:
2517
+ """Check for synchronous function calls that should be in actions.
2518
+
2519
+ Common patterns like len(), str(), etc. are not supported in workflow code.
2520
+ """
2521
+ func_name = None
2522
+ if isinstance(node.func, ast.Name):
2523
+ func_name = node.func.id
2524
+ elif isinstance(node.func, ast.Attribute):
2525
+ # Method calls on objects - check the method name
2526
+ func_name = node.func.attr
2527
+
2528
+ if not func_name:
2529
+ return
2530
+
2531
+ # Builtins that users commonly try to use
2532
+ common_builtins = {
2533
+ "len",
2534
+ "str",
2535
+ "int",
2536
+ "float",
2537
+ "bool",
2538
+ "list",
2539
+ "dict",
2540
+ "set",
2541
+ "tuple",
2542
+ "sum",
2543
+ "min",
2544
+ "max",
2545
+ "sorted",
2546
+ "reversed",
2547
+ "enumerate",
2548
+ "zip",
2549
+ "map",
2550
+ "filter",
2551
+ "range",
2552
+ "abs",
2553
+ "round",
2554
+ "print",
2555
+ "type",
2556
+ "isinstance",
2557
+ "hasattr",
2558
+ "getattr",
2559
+ "setattr",
2560
+ "open",
2561
+ "format",
2562
+ }
2563
+
2564
+ if func_name in common_builtins:
2565
+ line = getattr(node, "lineno", None)
2566
+ col = getattr(node, "col_offset", None)
2567
+ raise UnsupportedPatternError(
2568
+ f"Calling built-in function '{func_name}()' directly is not supported",
2569
+ RECOMMENDATIONS["builtin_call"],
2570
+ line=line,
2571
+ col=col,
2572
+ )
2573
+
2574
def _extract_action_call(self, node: ast.expr) -> Optional[ir.ActionCall]:
    """Return the ActionCall for an ``await <call>`` expression, if any.

    Recognised shapes, tried in order:
      1. ``await self.run_action(action(...), retry=..., timeout=...)``:
         the wrapped action, with retry/timeout policies attached.
      2. ``await asyncio.sleep(...)``: rewritten to the built-in sleep action.
      3. ``await action(...)``: a direct call to a registered action.

    Any other awaited call is passed to ``_check_non_action_await``, which
    may raise UnsupportedPatternError. Expressions that are not an await,
    or that await something other than a call, yield None.
    """
    if not isinstance(node, ast.Await) or not isinstance(node.value, ast.Call):
        return None
    target = node.value

    if self._is_run_action_call(target) and target.args:
        wrapped = self._extract_action_call_from_awaitable(target.args[0])
        if wrapped:
            self._extract_policies_from_run_action(target, wrapped)
            return wrapped

    if self._is_asyncio_sleep_call(target):
        return self._convert_asyncio_sleep_to_action(target)

    direct = self._extract_action_call_from_call(target)
    if direct:
        return direct

    # The awaited call is not an action; let the validator report it.
    self._check_non_action_await(node)
    return None
2607
+
2608
def _is_run_action_call(self, node: ast.Call) -> bool:
    """Return True when ``node`` calls an attribute named ``run_action``.

    Only the attribute name is checked, so both ``self.run_action(...)``
    and ``other.run_action(...)`` match. A bare ``run_action(...)`` does not.
    """
    callee = node.func
    return isinstance(callee, ast.Attribute) and callee.attr == "run_action"
2613
+
2614
def _extract_policies_from_run_action(
    self, run_action_call: ast.Call, action_call: ir.ActionCall
) -> None:
    """Attach retry/timeout policies from ``run_action`` keywords to ``action_call``.

    Handles calls such as:
      - self.run_action(action(), retry=RetryPolicy(attempts=3))
      - self.run_action(action(), timeout=timedelta(seconds=30))
      - self.run_action(action(), timeout=60)

    Each ``retry=`` / ``timeout=`` keyword whose value parses appends one
    PolicyBracket to ``action_call.policies``, in keyword order. Other
    keywords, and values that fail to parse, are ignored.
    """
    for keyword in run_action_call.keywords:
        if keyword.arg == "retry":
            field, parsed = "retry", self._parse_retry_policy(keyword.value)
        elif keyword.arg == "timeout":
            field, parsed = "timeout", self._parse_timeout_policy(keyword.value)
        else:
            continue

        if not parsed:
            continue

        bracket = ir.PolicyBracket()
        getattr(bracket, field).CopyFrom(parsed)
        action_call.policies.append(bracket)
2637
+
2638
+ def _parse_retry_policy(self, node: ast.expr) -> Optional[ir.RetryPolicy]:
2639
+ """Parse a RetryPolicy(...) call into IR.
2640
+
2641
+ Supports:
2642
+ - RetryPolicy(attempts=3)
2643
+ - RetryPolicy(attempts=3, exception_types=["ValueError"])
2644
+ - RetryPolicy(attempts=3, backoff_seconds=5)
2645
+ - self.retry_policy (instance attribute reference)
2646
+ """
2647
+ # Handle self.attr pattern - look up in instance attrs
2648
+ if (
2649
+ isinstance(node, ast.Attribute)
2650
+ and isinstance(node.value, ast.Name)
2651
+ and node.value.id == "self"
2652
+ ):
2653
+ attr_name = node.attr
2654
+ if attr_name in self._instance_attrs:
2655
+ return self._parse_retry_policy(self._instance_attrs[attr_name])
2656
+ return None
2657
+
2658
+ if not isinstance(node, ast.Call):
2659
+ return None
2660
+
2661
+ # Check if it's a RetryPolicy call
2662
+ func_name = None
2663
+ if isinstance(node.func, ast.Name):
2664
+ func_name = node.func.id
2665
+ elif isinstance(node.func, ast.Attribute):
2666
+ func_name = node.func.attr
2667
+
2668
+ if func_name != "RetryPolicy":
2669
+ return None
2670
+
2671
+ policy = ir.RetryPolicy()
2672
+
2673
+ for kw in node.keywords:
2674
+ if kw.arg == "attempts" and isinstance(kw.value, ast.Constant):
2675
+ # attempts means total executions, max_retries means retries after first attempt
2676
+ # So attempts=1 -> max_retries=0 (no retries), attempts=3 -> max_retries=2
2677
+ policy.max_retries = kw.value.value - 1
2678
+ elif kw.arg == "exception_types" and isinstance(kw.value, ast.List):
2679
+ for elt in kw.value.elts:
2680
+ if isinstance(elt, ast.Constant) and isinstance(elt.value, str):
2681
+ policy.exception_types.append(elt.value)
2682
+ elif kw.arg == "backoff_seconds" and isinstance(kw.value, ast.Constant):
2683
+ policy.backoff.seconds = int(kw.value.value)
2684
+
2685
+ return policy
2686
+
2687
def _parse_timeout_policy(self, node: ast.expr) -> Optional[ir.TimeoutPolicy]:
    """Parse a timeout value into an IR TimeoutPolicy.

    Supported forms:
      - timeout=60                      (int seconds)
      - timeout=30.5                    (float seconds, truncated to whole seconds)
      - timeout=timedelta(seconds=30)
      - timeout=timedelta(minutes=1.5)  (components are summed before truncation)
      - timeout=timedelta(1)            (positional args follow timedelta's order)
      - self.timeout                    (resolved through ``self._instance_attrs``)

    timedelta keywords days/seconds/microseconds/milliseconds/minutes/hours/weeks
    are understood. Only literal numeric arguments are used; others are skipped.

    Returns:
        The parsed policy, or None when ``node`` is not one of the forms above.
    """
    # Handle the self.attr pattern by looking it up in instance attrs.
    if (
        isinstance(node, ast.Attribute)
        and isinstance(node.value, ast.Name)
        and node.value.id == "self"
    ):
        attr_name = node.attr
        if attr_name in self._instance_attrs:
            return self._parse_timeout_policy(self._instance_attrs[attr_name])
        return None

    def literal_number(value_node: ast.expr) -> Optional[Union[int, float]]:
        """Return the numeric literal held by ``value_node``, else None."""
        # bool is an int subclass, but `timeout=True` is not a duration.
        if (
            isinstance(value_node, ast.Constant)
            and isinstance(value_node.value, (int, float))
            and not isinstance(value_node.value, bool)
        ):
            return value_node.value
        return None

    # Direct numeric value in seconds; fractional seconds are truncated.
    direct = literal_number(node)
    if direct is not None:
        policy = ir.TimeoutPolicy()
        policy.timeout.seconds = int(direct)
        return policy

    if not isinstance(node, ast.Call):
        return None

    # Accept both timedelta(...) and datetime.timedelta(...).
    func_name = None
    if isinstance(node.func, ast.Name):
        func_name = node.func.id
    elif isinstance(node.func, ast.Attribute):
        func_name = node.func.attr
    if func_name != "timedelta":
        return None

    # Seconds per unit, listed in datetime.timedelta's positional-parameter order.
    unit_seconds: Dict[str, float] = {
        "days": 86400.0,
        "seconds": 1.0,
        "microseconds": 1e-6,
        "milliseconds": 1e-3,
        "minutes": 60.0,
        "hours": 3600.0,
        "weeks": 604800.0,
    }

    total = 0.0
    for unit, arg in zip(unit_seconds, node.args):
        amount = literal_number(arg)
        if amount is not None:
            total += amount * unit_seconds[unit]
    for kw in node.keywords:
        amount = literal_number(kw.value)
        # kw.arg is None for a `**mapping` splat, which is never a known unit.
        if kw.arg in unit_seconds and amount is not None:
            total += amount * unit_seconds[kw.arg]

    policy = ir.TimeoutPolicy()
    # Sum fractional components first (minutes=1.5 -> 90s), then round to
    # microsecond precision so accumulated float error such as
    # 41.99999999999999 does not truncate down a whole second.
    policy.timeout.seconds = int(round(total, 6))
    return policy
2740
+
2741
def _is_asyncio_sleep_call(self, node: ast.Call) -> bool:
    """Return True if ``node`` invokes ``asyncio.sleep``.

    Recognised spellings:
      - import asyncio; asyncio.sleep(1)
      - from asyncio import sleep; sleep(1)
      - from asyncio import sleep as s; s(1)

    Attribute access matches only when the receiver is literally the name
    ``asyncio``. Bare names are resolved through ``self._imported_names``.
    """
    callee = node.func
    if isinstance(callee, ast.Attribute):
        receiver = callee.value
        return (
            callee.attr == "sleep"
            and isinstance(receiver, ast.Name)
            and receiver.id == "asyncio"
        )
    if isinstance(callee, ast.Name) and callee.id in self._imported_names:
        origin = self._imported_names[callee.id]
        return origin.module == "asyncio" and origin.original_name == "sleep"
    return False
2760
+
2761
def _convert_asyncio_sleep_to_action(self, node: ast.Call) -> ir.ActionCall:
    """Rewrite ``asyncio.sleep(duration)`` as the built-in ``sleep`` action.

    Per the scheduler design, this action is a durable sleep stored with a
    future scheduled time. The duration is taken from:
      - the first positional argument, if there is one; otherwise
      - the first keyword named ``seconds``, ``delay`` or ``duration``.
    It is emitted as a single ``duration`` kwarg, or left out if the
    expression cannot be converted to IR.
    """
    sleep_call = ir.ActionCall(action_name="sleep")

    if node.args:
        source = node.args[0]
    else:
        source = next(
            (
                kw.value
                for kw in node.keywords
                if kw.arg in ("seconds", "delay", "duration")
            ),
            None,
        )

    if source is not None:
        duration = self._expr_to_ir_with_model_coercion(source)
        if duration:
            sleep_call.kwargs.append(ir.Kwarg(name="duration", value=duration))

    return sleep_call
2785
+
2786
def _is_asyncio_gather_call(self, node: ast.Call) -> bool:
    """Return True if ``node`` invokes ``asyncio.gather``.

    Recognised spellings:
      - import asyncio; asyncio.gather(a(), b())
      - from asyncio import gather; gather(a(), b())
      - from asyncio import gather as g; g(a(), b())

    Attribute access matches only when the receiver is literally the name
    ``asyncio``. Bare names are resolved through ``self._imported_names``.
    """
    callee = node.func
    if isinstance(callee, ast.Attribute):
        receiver = callee.value
        return (
            callee.attr == "gather"
            and isinstance(receiver, ast.Name)
            and receiver.id == "asyncio"
        )
    if isinstance(callee, ast.Name) and callee.id in self._imported_names:
        origin = self._imported_names[callee.id]
        return origin.module == "asyncio" and origin.original_name == "gather"
    return False
2805
+
2806
def _convert_asyncio_gather(
    self, node: ast.Call
) -> Optional[Union[ir.ParallelExpr, ir.SpreadExpr]]:
    """Convert an ``asyncio.gather(...)`` call into ParallelExpr or SpreadExpr IR.

    Two shapes are accepted:
      1. Static gather: ``asyncio.gather(a(), b(), c())`` -> ParallelExpr
      2. Spread gather: ``asyncio.gather(*[action(x) for x in items])`` -> SpreadExpr

    The call must pass ``return_exceptions=True`` literally. ``**kwargs``
    splats are rejected wherever they appear, because they cannot be
    inspected statically.

    Args:
        node: The asyncio.gather() Call node.

    Returns:
        A ParallelExpr or SpreadExpr, or None when none of a static gather's
        arguments could be converted.

    Raises:
        UnsupportedPatternError: if return_exceptions=True is missing, a
            ``**kwargs`` splat is used, a starred argument is anything other
            than a lone list comprehension, or a starred argument is mixed
            with other arguments.
    """
    line = getattr(node, "lineno", None)
    col = getattr(node, "col_offset", None)

    return_exceptions_value: Optional[ast.expr] = None
    for kw in node.keywords:
        if kw.arg is None:
            # A `**mapping` splat could hide or override return_exceptions,
            # so it is rejected even when it follows return_exceptions=.
            raise UnsupportedPatternError(
                "asyncio.gather() must be called with return_exceptions=True",
                RECOMMENDATIONS["gather_return_exceptions"],
                line=line,
                col=col,
            )
        if kw.arg == "return_exceptions":
            return_exceptions_value = kw.value

    if not (
        isinstance(return_exceptions_value, ast.Constant)
        and return_exceptions_value.value is True
    ):
        raise UnsupportedPatternError(
            "asyncio.gather() must be called with return_exceptions=True",
            RECOMMENDATIONS["gather_return_exceptions"],
            line=line,
            col=col,
        )

    starred_args = [arg for arg in node.args if isinstance(arg, ast.Starred)]

    # Spread pattern: a single starred list comprehension.
    if starred_args and len(node.args) == 1:
        spread_source = starred_args[0].value
        if isinstance(spread_source, ast.ListComp):
            return self._convert_listcomp_to_spread_expr(spread_source)
        if isinstance(spread_source, ast.Name):
            raise UnsupportedPatternError(
                f"Spreading variable '{spread_source.id}' in asyncio.gather() is not supported",
                RECOMMENDATIONS["gather_variable_spread"],
                line=line,
                col=col,
            )
        raise UnsupportedPatternError(
            "Spreading non-list-comprehension expressions in asyncio.gather() is not supported",
            RECOMMENDATIONS["gather_variable_spread"],
            line=line,
            col=col,
        )

    if starred_args:
        # A starred argument next to other arguments is not a Call, so it
        # would otherwise be dropped below, changing the number of results.
        raise UnsupportedPatternError(
            "Mixing starred and positional arguments in asyncio.gather() is not supported",
            RECOMMENDATIONS["gather_variable_spread"],
            line=line,
            col=col,
        )

    # Static case: gather(a(), b(), c()) -> ParallelExpr
    parallel = ir.ParallelExpr()
    for arg in node.args:
        call = self._convert_gather_arg_to_call(arg)
        if call:
            parallel.calls.append(call)

    if not parallel.calls:
        return None
    return parallel
2889
+
2890
def _convert_listcomp_to_spread_expr(self, listcomp: ast.ListComp) -> Optional[ir.SpreadExpr]:
    """Turn ``[action(x=item) for item in collection]`` into SpreadExpr IR.

    The element may also be wrapped in run_action:
        [self.run_action(action(x=item), retry=..., timeout=...) for item in collection]
    in which case the retry/timeout policies are attached to the action.

    Requirements; each violation raises UnsupportedPatternError located at
    the comprehension:
      - exactly one ``for`` clause, with no ``if`` filters;
      - the loop target is a plain name;
      - the iterable converts to an IR expression;
      - the element is a call to a registered @action, optionally via run_action.

    Args:
        listcomp: The ListComp AST node.

    Returns:
        The SpreadExpr. Every failure path in this method raises rather than
        returning None.
    """

    def reject(message: str, recommendation: str) -> NoReturn:
        raise UnsupportedPatternError(
            message,
            recommendation,
            line=getattr(listcomp, "lineno", None),
            col=getattr(listcomp, "col_offset", None),
        )

    if len(listcomp.generators) != 1:
        reject(
            "Spread pattern only supports a single loop variable",
            "Use a simple list comprehension: [action(x) for x in items]",
        )
    (generator,) = listcomp.generators

    if generator.ifs:
        reject(
            "Spread pattern does not support conditions in list comprehension",
            "Remove the 'if' clause from the comprehension",
        )

    target = generator.target
    if not isinstance(target, ast.Name):
        reject(
            "Spread pattern requires a simple loop variable",
            "Use a simple variable: [action(x) for x in items]",
        )

    collection = self._expr_to_ir_with_model_coercion(generator.iter)
    if not collection:
        reject(
            "Could not convert collection expression in spread pattern",
            "Ensure the collection is a simple variable or expression",
        )

    element = listcomp.elt
    if not isinstance(element, ast.Call):
        reject(
            "Spread pattern requires an action call in the list comprehension",
            "Use: [action(x=item) for item in items]",
        )

    action_call: Optional[ir.ActionCall]
    if not self._is_run_action_call(element):
        action_call = self._extract_action_call_from_call(element)
    else:
        # run_action(action(...), ...): the action is the first positional argument.
        action_call = None
        inner = element.args[0] if element.args else None
        if isinstance(inner, ast.Call):
            action_call = self._extract_action_call_from_call(inner)
            if action_call:
                self._extract_policies_from_run_action(element, action_call)

    if not action_call:
        reject(
            "Spread pattern element must be an @action call",
            "Ensure the function is decorated with @action, or use self.run_action(action(...), ...)",
        )

    spread = ir.SpreadExpr()
    spread.collection.CopyFrom(collection)
    spread.loop_var = target.id
    spread.action.CopyFrom(action_call)
    return spread
2997
+
2998
def _convert_gather_arg_to_call(self, node: ast.expr) -> Optional[ir.Call]:
    """Wrap one asyncio.gather() argument as an IR Call.

    A call to a registered @action fills ``Call.action``. Otherwise, a call
    that ``_convert_to_function_call`` can handle fills ``Call.function``.
    Non-call arguments, and calls that convert to neither, yield None.
    """
    if not isinstance(node, ast.Call):
        return None

    action = self._extract_action_call_from_call(node)
    function = None if action else self._convert_to_function_call(node)
    if not (action or function):
        return None

    wrapped = ir.Call()
    if action:
        wrapped.action.CopyFrom(action)
    else:
        wrapped.function.CopyFrom(function)
    return wrapped
3021
+
3022
def _convert_to_function_call(self, node: ast.Call) -> Optional[ir.FunctionCall]:
    """Build an IR FunctionCall from a Python call node.

    The callee's dotted name comes from ``_get_func_name``; None is returned
    when no name can be derived. Positional and named arguments are
    converted with model coercion. Arguments that fail to convert, and
    ``**`` splats, are skipped. Missing defaults from a known function
    signature are then filled in.
    """
    name = self._get_func_name(node.func)
    if not name:
        return None

    call = ir.FunctionCall(name=name)
    builtin = _global_function_for_call(name, node)
    if builtin is not None:
        call.global_function = builtin

    for positional in node.args:
        converted = self._expr_to_ir_with_model_coercion(positional)
        if converted:
            call.args.append(converted)

    for keyword in node.keywords:
        if not keyword.arg:
            continue  # `**mapping` splat: there is no name to attach
        converted = self._expr_to_ir_with_model_coercion(keyword.value)
        if converted:
            call.kwargs.append(ir.Kwarg(name=keyword.arg, value=converted))

    self._fill_default_kwargs_for_expr(call)
    return call
3050
+
3051
def _get_func_name(self, node: ast.expr) -> Optional[str]:
    """Return a dotted name for a callee expression.

    - ``foo`` -> ``"foo"``
    - ``a.b.c`` -> ``"a.b.c"``
    - a leading ``self.`` is dropped: ``self.helper`` -> ``"helper"``
    - for an attribute chain not rooted in a plain name (e.g. ``f().g``),
      only the attribute parts are joined
    - any other expression -> None
    """
    if isinstance(node, ast.Name):
        return node.id
    if not isinstance(node, ast.Attribute):
        return None

    segments: List[str] = []
    cursor: ast.expr = node
    while isinstance(cursor, ast.Attribute):
        segments.insert(0, cursor.attr)
        cursor = cursor.value
    if isinstance(cursor, ast.Name):
        segments.insert(0, cursor.id)

    dotted = ".".join(segments)
    prefix = "self."
    return dotted[len(prefix):] if dotted.startswith(prefix) else dotted
3069
+
3070
def _convert_model_constructor_if_needed(self, node: ast.Call) -> Optional[ir.Expr]:
    """Return dict-expression IR for a model constructor call, else None.

    ``_is_model_constructor`` decides whether ``node`` builds a known model
    and returns its name. If it does, conversion is delegated to
    ``_convert_model_constructor_to_dict``. This method is passed as the
    ``model_converter`` hook to ``_expr_to_ir``.
    """
    model = self._is_model_constructor(node)
    return self._convert_model_constructor_to_dict(node, model) if model else None
3075
+
3076
def _resolve_enum_attribute(self, node: ast.Attribute) -> Optional[ir.Expr]:
    """Inline an ``Enum.MEMBER`` reference as a literal IR expression.

    The dotted attribute is resolved against ``self._module_globals``. If it
    does not name an Enum member, or the member's value is None, this
    returns None so the caller can fall back to normal handling.

    Raises:
        UnsupportedPatternError: if the member's value cannot be expressed as
            a primitive IR literal.
    """
    resolved = _resolve_enum_attribute_value(node, self._module_globals)
    if resolved is None:
        return None

    literal = _constant_to_literal(resolved)
    if literal is None:
        raise UnsupportedPatternError(
            "Enum value must be a primitive literal",
            RECOMMENDATIONS["unsupported_literal"],
            line=getattr(node, "lineno", None),
            col=getattr(node, "col_offset", None),
        )

    inlined = ir.Expr(span=_make_span(node))
    inlined.literal.CopyFrom(literal)
    return inlined
3093
+
3094
def _is_exception_class(self, class_name: str) -> bool:
    """Return True if ``class_name`` names an exception class.

    Builtin exceptions (ValueError, KeyboardInterrupt, ...) are looked up in
    the ``builtins`` module first. After that, ``self._module_globals`` is
    checked for exception classes imported into or defined in the
    workflow's module.

    The ``builtins`` module is used directly because ``__builtins__`` is a
    CPython implementation detail: it is a dict in imported modules, where
    ``getattr`` on it does not find builtin names.
    """
    import builtins

    def is_exception_type(candidate: object) -> bool:
        return isinstance(candidate, type) and issubclass(candidate, BaseException)

    if is_exception_type(getattr(builtins, class_name, None)):
        return True
    return is_exception_type(self._module_globals.get(class_name))
3123
+
3124
def _expr_to_ir_with_model_coercion(self, node: ast.expr) -> Optional[ir.Expr]:
    """Convert ``node`` to IR with this builder's resolution hooks installed.

    The hooks turn model constructor calls into dict expressions, inline
    Enum member references as literals, and resolve exception class names
    (used by the isinstance() -> isexception rewrite). Missing default
    kwargs are then filled in for every function call in the result.

    Returns:
        The IR expression, or None if ``node`` could not be converted.
    """
    converted = _expr_to_ir(
        node,
        model_converter=self._convert_model_constructor_if_needed,
        enum_resolver=self._resolve_enum_attribute,
        exception_class_resolver=self._is_exception_class,
    )
    if converted is None:
        return None
    self._fill_default_kwargs_recursive(converted)
    return converted
3136
+
3137
def _fill_default_kwargs_recursive(self, expr: ir.Expr) -> None:
    """Walk ``expr`` and fill default kwargs into every nested function call.

    The walk descends into:
      - function-call args and kwarg values,
      - binary and unary operands,
      - list elements,
      - dict keys and values,
      - index object and index,
      - dot objects.
    Every other expression kind is a leaf. ``expr`` is mutated in place.
    """
    children: List[ir.Expr] = []

    if expr.HasField("function_call"):
        fn_call = expr.function_call
        # Defaults are appended first, so the kwargs walked below include them.
        self._fill_default_kwargs_for_expr(fn_call)
        children.extend(fn_call.args)
        children.extend(kwarg.value for kwarg in fn_call.kwargs if kwarg.value)
    elif expr.HasField("binary_op"):
        operands = (expr.binary_op.left, expr.binary_op.right)
        children.extend(operand for operand in operands if operand)
    elif expr.HasField("unary_op"):
        if expr.unary_op.operand:
            children.append(expr.unary_op.operand)
    elif expr.HasField("list"):
        children.extend(expr.list.elements)
    elif expr.HasField("dict"):
        for entry in expr.dict.entries:
            children.extend(part for part in (entry.key, entry.value) if part)
    elif expr.HasField("index"):
        parts = (expr.index.object, expr.index.index)
        children.extend(part for part in parts if part)
    elif expr.HasField("dot"):
        if expr.dot.object:
            children.append(expr.dot.object)

    for child in children:
        self._fill_default_kwargs_recursive(child)
3172
+
3173
def _fill_default_kwargs_for_expr(self, fn_call: ir.FunctionCall) -> None:
    """Append kwargs for omitted parameters that declare defaults.

    The signature is looked up by ``fn_call.name`` in
    ``self._ctx.function_signatures``; unknown functions are left untouched.
    A parameter counts as supplied if a positional argument occupies its
    slot or a kwarg with its name already exists. Defaults that
    ``_constant_to_literal`` cannot represent are skipped.
    """
    signature = self._ctx.function_signatures.get(fn_call.name)
    if signature is None:
        return

    params = signature.parameters
    filled_slots = min(len(fn_call.args), len(params))
    supplied = {params[i].name for i in range(filled_slots)}
    supplied.update(kwarg.name for kwarg in fn_call.kwargs)

    # Computed before appending, so newly added kwargs don't affect the check.
    needs_default = [
        param for param in params if param.name not in supplied and param.has_default
    ]
    for param in needs_default:
        literal = _constant_to_literal(param.default_value)
        if literal is None:
            continue
        default_expr = ir.Expr()
        default_expr.literal.CopyFrom(literal)
        fn_call.kwargs.append(ir.Kwarg(name=param.name, value=default_expr))
3205
+
3206
def _extract_action_call_from_awaitable(self, node: ast.expr) -> Optional[ir.ActionCall]:
    """Return the ActionCall for ``node`` when it is a call to an action.

    Non-call expressions yield None. Calls are delegated to
    ``_extract_action_call_from_call``.
    """
    return self._extract_action_call_from_call(node) if isinstance(node, ast.Call) else None
3211
+
3212
def _extract_action_call_from_call(self, node: ast.Call) -> Optional[ir.ActionCall]:
    """Build an IR ActionCall from a call to a registered @action.

    Positional arguments are turned into keyword arguments using the
    action's signature, so every argument in the IR is named. Only
    parameters that can actually receive positional arguments
    (positional-only or positional-or-keyword) are used for this mapping,
    so ``*args``, keyword-only and ``**kwargs`` parameters never receive
    a positional value by mistake.

    Pydantic model and dataclass constructors passed as arguments are
    converted to dict expressions. Arguments whose expression cannot be
    converted, and ``**`` splats, are skipped.

    Returns:
        The ActionCall, or None if the callee is not a registered action.

    Raises:
        UnsupportedPatternError: if more positional arguments are passed than
            the action has named positional parameters, since the extras
            could not be expressed as keyword arguments.
    """
    action_name = self._get_action_name(node.func)
    if not action_name or action_name not in self._action_defs:
        return None

    action_def = self._action_defs[action_name]
    action_call = ir.ActionCall(action_name=action_def.action_name)

    # Set the module name so the worker knows where to find the action.
    if action_def.module_name:
        action_call.module_name = action_def.module_name

    positional_kinds = (
        inspect.Parameter.POSITIONAL_ONLY,
        inspect.Parameter.POSITIONAL_OR_KEYWORD,
    )
    # getattr keeps this working if a signature's parameters lack `.kind`;
    # such parameters are treated as ordinary positional-or-keyword ones.
    positional_names = [
        name
        for name, param in action_def.signature.parameters.items()
        if getattr(param, "kind", inspect.Parameter.POSITIONAL_OR_KEYWORD)
        in positional_kinds
    ]

    if len(node.args) > len(positional_names):
        raise UnsupportedPatternError(
            f"Too many positional arguments for action '{action_name}()': "
            f"expected at most {len(positional_names)}, got {len(node.args)}",
            "Pass action arguments by keyword, matching the action's parameter names.",
            line=getattr(node, "lineno", None),
            col=getattr(node, "col_offset", None),
        )

    # Positional args -> kwargs named after the matching parameters.
    for param_name, arg in zip(positional_names, node.args):
        expr = self._expr_to_ir_with_model_coercion(arg)
        if expr:
            action_call.kwargs.append(ir.Kwarg(name=param_name, value=expr))

    # Explicit keyword args; `**mapping` splats (kw.arg is None) have no name.
    for kw in node.keywords:
        if kw.arg:
            expr = self._expr_to_ir_with_model_coercion(kw.value)
            if expr:
                action_call.kwargs.append(ir.Kwarg(name=kw.arg, value=expr))

    return action_call
3257
+
3258
def _get_action_name(self, func: ast.expr) -> Optional[str]:
    """Return the bare callee name used to look up an action.

    ``fetch(...)`` -> ``"fetch"``. For ``module.fetch(...)`` or
    ``self.fetch(...)`` only the last attribute is kept, giving ``"fetch"``.
    Any other callee expression yields None.
    """
    for node_type, field in ((ast.Name, "id"), (ast.Attribute, "attr")):
        if isinstance(func, node_type):
            return getattr(func, field)
    return None
3265
+
3266
def _get_assign_target(self, targets: List[ast.expr]) -> Optional[str]:
    """Return the name bound by the first assignment target, if it is a plain name.

    Only ``targets[0]`` is examined. An empty list, or any other target
    shape (tuple, subscript, attribute, ...), yields None.
    """
    first = targets[0] if targets else None
    return first.id if isinstance(first, ast.Name) else None
3271
+
3272
def _get_assign_targets(self, targets: List[ast.expr]) -> List[str]:
    """Collect the variable names written by an assignment's targets.

    - ``x = ...``      -> ``["x"]``
    - ``a, b = ...``   -> ``["a", "b"]`` (only plain names inside the tuple)
    - ``obj[k] = ...`` -> the string from ``_format_subscript_target``,
                          when that string is non-empty
    Other target shapes contribute nothing. Order follows the targets.
    """
    names: List[str] = []
    for target in targets:
        if isinstance(target, ast.Tuple):
            names.extend(elt.id for elt in target.elts if isinstance(elt, ast.Name))
        elif isinstance(target, ast.Name):
            names.append(target.id)
        elif isinstance(target, ast.Subscript):
            rendered = _format_subscript_target(target)
            if rendered:
                names.append(rendered)
    return names
3287
+
3288
+
3289
def _make_span(node: ast.AST) -> ir.Span:
    """Build an IR Span from an AST node's source position.

    A missing start line or column defaults to 0. A missing end position,
    or one set to None (which ``ast`` allows), also becomes 0.
    """
    start_line = getattr(node, "lineno", 0)
    start_col = getattr(node, "col_offset", 0)
    end_line = getattr(node, "end_lineno", 0) or 0
    end_col = getattr(node, "end_col_offset", 0) or 0
    return ir.Span(
        start_line=start_line,
        start_col=start_col,
        end_line=end_line,
        end_col=end_col,
    )
3297
+
3298
+
3299
def _attribute_chain(node: ast.Attribute) -> Optional[List[str]]:
    """Flatten a dotted attribute expression into its name parts.

    ``a.b.c`` -> ``["a", "b", "c"]``. Returns None when the chain is not
    rooted in a plain name, e.g. ``f().b`` or ``x[0].y``.
    """
    reversed_parts: List[str] = []
    cursor: ast.AST = node
    while isinstance(cursor, ast.Attribute):
        reversed_parts.append(cursor.attr)
        cursor = cursor.value
    if not isinstance(cursor, ast.Name):
        return None
    reversed_parts.append(cursor.id)
    return reversed_parts[::-1]
3309
+
3310
+
3311
def _resolve_enum_attribute_value(
    node: ast.Attribute,
    module_globals: Mapping[str, Any],
) -> Optional[Any]:
    """Return the value of an Enum member named by a dotted attribute.

    The first name is looked up in ``module_globals``. Intermediate parts
    are followed through each object's own ``__dict__``, so nested classes
    and module attributes resolve but inherited attributes do not. The last
    part must be a member of an Enum class.

    Returns:
        The member's ``.value``, or None if any step fails to resolve.
    """
    chain = _attribute_chain(node)
    if not chain or len(chain) < 2:
        return None

    *path, member_name = chain
    root, *middle = path

    owner = module_globals.get(root)
    if owner is None:
        return None

    for part in middle:
        try:
            namespace = owner.__dict__
        except AttributeError:
            return None
        owner = namespace.get(part)
        if owner is None:
            return None

    if not isinstance(owner, EnumMeta):
        return None
    member = owner.__members__.get(member_name)
    return None if member is None else member.value
3340
+
3341
+
3342
def _try_convert_isinstance_to_isexception(
    expr: ast.Call,
    exception_class_resolver: Callable[[str], bool],
    model_converter: Optional[Callable[[ast.Call], Optional[ir.Expr]]] = None,
    enum_resolver: Optional[Callable[[ast.Attribute], Optional[ir.Expr]]] = None,
) -> Optional[ir.Expr]:
    """Rewrite ``isinstance(x, Exc)`` as ``isexception(x, "Exc")``.

    A tuple of classes becomes a list of names:
    ``isinstance(x, (A, B))`` -> ``isexception(x, ["A", "B"])``.

    Returns:
        The ``isexception`` function-call IR, or None when ``expr`` is not an
        isinstance() call or its first argument cannot be converted.

    Raises:
        UnsupportedPatternError: if isinstance() does not receive exactly two
            positional arguments, if the class argument is not a name or a
            tuple of names, or if a named class is not an exception class.
    """
    if _get_func_name(expr.func) != "isinstance":
        return None

    def fail(message: str, recommendation: str) -> NoReturn:
        raise UnsupportedPatternError(
            message,
            recommendation,
            line=getattr(expr, "lineno", None),
            col=getattr(expr, "col_offset", None),
        )

    shape_message = "isinstance() class argument must be a simple name or tuple of names"
    shape_hint = "Use simple class names like ValueError or (ValueError, TypeError)."

    if len(expr.args) != 2 or expr.keywords:
        fail(
            "isinstance() requires exactly 2 positional arguments",
            RECOMMENDATIONS["sync_function_call"],
        )

    value_node, class_node = expr.args

    class_nodes: List[ast.expr]
    if isinstance(class_node, ast.Name):
        class_nodes = [class_node]
    elif isinstance(class_node, ast.Tuple):
        class_nodes = class_node.elts
    else:
        fail(shape_message, shape_hint)

    # Validate each class in order: shape first, then exception-ness.
    exception_names: List[str] = []
    for candidate in class_nodes:
        if not isinstance(candidate, ast.Name):
            fail(shape_message, shape_hint)
        if not exception_class_resolver(candidate.id):
            fail(
                f"isinstance() with non-exception class '{candidate.id}' is not supported in workflows",
                "isinstance() can only be used to check exception types in workflow code. "
                "Move other type checks to an @action.",
            )
        exception_names.append(candidate.id)

    value_expr = _expr_to_ir(
        value_node,
        model_converter=model_converter,
        enum_resolver=enum_resolver,
        exception_class_resolver=exception_class_resolver,
    )
    if not value_expr:
        return None

    # One class -> a string literal; otherwise a list of string literals.
    type_arg = ir.Expr(span=_make_span(class_node))
    if len(exception_names) == 1:
        type_arg.literal.CopyFrom(ir.Literal(string_value=exception_names[0]))
    else:
        name_exprs = []
        for exc_name in exception_names:
            name_expr = ir.Expr()
            name_expr.literal.CopyFrom(ir.Literal(string_value=exc_name))
            name_exprs.append(name_expr)
        type_arg.list.CopyFrom(ir.ListExpr(elements=name_exprs))

    isexception_call = ir.FunctionCall(
        name="isexception",
        args=[value_expr, type_arg],
        kwargs=[],
    )
    isexception_call.global_function = ir.GlobalFunction.GLOBAL_FUNCTION_ISEXCEPTION

    rewritten = ir.Expr(span=_make_span(expr))
    rewritten.function_call.CopyFrom(isexception_call)
    return rewritten
3460
+
3461
+
3462
def _expr_to_ir(
    expr: ast.AST,
    model_converter: Optional[Callable[[ast.Call], Optional[ir.Expr]]] = None,
    enum_resolver: Optional[Callable[[ast.Attribute], Optional[ir.Expr]]] = None,
    exception_class_resolver: Optional[Callable[[str], bool]] = None,
) -> Optional[ir.Expr]:
    """Convert a Python AST expression to an IR Expr.

    All recursive conversions forward the same three callbacks. Nested
    sub-expressions (call arguments, dict keys/values, comparison operands,
    subscript indices, ...) are therefore converted exactly like top-level
    ones. In particular, ``isinstance(x, SomeError)`` is rewritten to
    ``isexception`` wherever it appears.

    Args:
        expr: The AST expression to convert.
        model_converter: Optional callback to convert model constructors. A
            non-None result for an ``ast.Call`` is returned as-is.
        enum_resolver: Optional callback to resolve enum attributes. A
            non-None result for an ``ast.Attribute`` is returned as-is.
        exception_class_resolver: Optional callback that takes a class name and
            returns True if it's an exception class. Used to transform
            ``isinstance(x, ExceptionClass)`` to
            ``isexception(x, "ExceptionClass")``.

    Returns:
        The IR expression, or None if a required sub-expression could not be
        converted.

    Raises:
        UnsupportedPatternError: If the expression uses a construct the IR
            cannot represent. This includes chained comparisons
            (``a < b < c``), dict unpacking (``{**d}``), keyword unpacking
            (``f(**kw)``), and direct calls to non-whitelisted synchronous
            functions.
    """
    result = ir.Expr(span=_make_span(expr))

    def convert(node: ast.AST) -> Optional[ir.Expr]:
        """Recursively convert ``node`` with the same callbacks as this call."""
        return _expr_to_ir(
            node,
            model_converter=model_converter,
            enum_resolver=enum_resolver,
            exception_class_resolver=exception_class_resolver,
        )

    def unsupported(message: str, recommendation: str) -> NoReturn:
        """Raise UnsupportedPatternError located at ``expr``."""
        raise UnsupportedPatternError(
            message,
            recommendation,
            line=getattr(expr, "lineno", None),
            col=getattr(expr, "col_offset", None),
        )

    def convert_call(call: ast.Call, func_name: str) -> ir.Expr:
        """Fill ``result`` with a FunctionCall built from ``call``'s arguments."""
        args = [convert(a) for a in call.args]
        kwargs: List[ir.Kwarg] = []
        for kw in call.keywords:
            if kw.arg is None:
                # ``f(**mapping)`` has no IR form. Dropping it would silently
                # change which arguments the call receives.
                unsupported(
                    f"Keyword argument unpacking (**) in call to '{func_name}()' is not supported",
                    "Pass each keyword argument explicitly, e.g. f(a=a, b=b).",
                )
            kw_expr = convert(kw.value)
            if kw_expr is not None:
                kwargs.append(ir.Kwarg(name=kw.arg, value=kw_expr))
        func_call = ir.FunctionCall(
            name=func_name,
            args=[a for a in args if a is not None],
            kwargs=kwargs,
        )
        global_function = _global_function_for_call(func_name, call)
        if global_function is not None:
            func_call.global_function = global_function
        result.function_call.CopyFrom(func_call)
        return result

    if isinstance(expr, ast.Call) and model_converter:
        converted = model_converter(expr)
        if converted is not None:
            return converted

    # Handle isinstance(x, ExceptionClass) -> isexception(x, "ExceptionClass")
    if isinstance(expr, ast.Call) and exception_class_resolver:
        isinstance_result = _try_convert_isinstance_to_isexception(
            expr, exception_class_resolver, model_converter, enum_resolver
        )
        if isinstance_result is not None:
            return isinstance_result

    if isinstance(expr, ast.Name):
        result.variable.CopyFrom(ir.Variable(name=expr.id))
        return result

    if isinstance(expr, ast.Constant):
        literal = _constant_to_literal(expr.value)
        if literal is not None:
            result.literal.CopyFrom(literal)
            return result

    if isinstance(expr, ast.BinOp):
        left = convert(expr.left)
        right = convert(expr.right)
        op = _bin_op_to_ir(expr.op)
        if left is not None and right is not None and op is not None:
            result.binary_op.CopyFrom(ir.BinaryOp(left=left, op=op, right=right))
            return result

    if isinstance(expr, ast.UnaryOp):
        operand = convert(expr.operand)
        op = _unary_op_to_ir(expr.op)
        if operand is not None and op is not None:
            result.unary_op.CopyFrom(ir.UnaryOp(op=op, operand=operand))
            return result

    if isinstance(expr, ast.Compare):
        if len(expr.ops) > 1:
            # Lowering only the first operator would silently turn
            # ``a < b < c`` into ``a < b``, so reject the chain instead.
            unsupported(
                "Chained comparisons (e.g. 'a < b < c') are not supported",
                "Split the chain into separate comparisons joined with 'and', "
                "e.g. 'a < b and b < c'.",
            )
        left = convert(expr.left)
        if left is None:
            return None
        if expr.ops and expr.comparators:
            op = _cmp_op_to_ir(expr.ops[0])
            right = convert(expr.comparators[0])
            if op is not None and right is not None:
                result.binary_op.CopyFrom(ir.BinaryOp(left=left, op=op, right=right))
                return result

    if isinstance(expr, ast.BoolOp):
        values = [convert(v) for v in expr.values]
        op = _bool_op_to_ir(expr.op)
        if op is not None and len(values) >= 2 and all(v is not None for v in values):
            # Fold left: a and b and c -> (a and b) and c. Each synthesized
            # node keeps the BoolOp's span so it can be located later.
            combined = values[0]
            for value in values[1:]:
                node = ir.Expr(span=_make_span(expr))
                node.binary_op.CopyFrom(ir.BinaryOp(left=combined, op=op, right=value))
                combined = node
            return combined

    if isinstance(expr, (ast.List, ast.Tuple)):
        # Tuples are represented as lists in the IR.
        elements = [convert(e) for e in expr.elts]
        if all(e is not None for e in elements):
            result.list.CopyFrom(ir.ListExpr(elements=elements))
            return result

    if isinstance(expr, ast.Dict):
        entries: List[ir.DictEntry] = []
        for key, value in zip(expr.keys, expr.values, strict=False):
            if key is None:
                # ``{**other}`` has no IR form. Skipping it would silently drop
                # every entry of ``other`` from the result.
                unsupported(
                    "Dict unpacking (**mapping) is not supported",
                    "List the dict entries explicitly, or build the merged dict "
                    "inside an @action.",
                )
            key_expr = convert(key)
            value_expr = convert(value)
            if key_expr is not None and value_expr is not None:
                entries.append(ir.DictEntry(key=key_expr, value=value_expr))
        result.dict.CopyFrom(ir.DictExpr(entries=entries))
        return result

    if isinstance(expr, ast.Subscript):
        obj = convert(expr.value)
        index = convert(expr.slice)
        if obj is not None and index is not None:
            result.index.CopyFrom(ir.IndexAccess(object=obj, index=index))
            return result

    if isinstance(expr, ast.Attribute):
        if enum_resolver:
            resolved = enum_resolver(expr)
            if resolved is not None:
                return resolved
        obj = convert(expr.value)
        if obj is not None:
            result.dot.CopyFrom(ir.DotAccess(object=obj, attribute=expr.attr))
            return result

    if isinstance(expr, ast.Await) and isinstance(expr.value, ast.Call):
        func_name = _get_func_name(expr.value.func)
        if func_name:
            return convert_call(expr.value, func_name)

    if isinstance(expr, ast.Call):
        if not _is_self_method_call(expr):
            display_name = _get_func_name(expr.func) or "unknown"
            # Attribute calls (obj.method()) are never allowed. Bare-name calls
            # are allowed only when whitelisted.
            if (
                isinstance(expr.func, ast.Attribute)
                or display_name not in ALLOWED_SYNC_FUNCTIONS
            ):
                unsupported(
                    f"Calling synchronous function '{display_name}()' directly is not supported",
                    RECOMMENDATIONS["sync_function_call"],
                )
        func_name = _get_func_name(expr.func)
        if func_name:
            return convert_call(expr, func_name)

    # Nothing above could convert the expression. Raise a descriptive error
    # for known-unsupported shapes.
    _check_unsupported_expression(expr)

    return None
3750
+
3751
+
3752
def _check_unsupported_expression(expr: ast.AST) -> None:
    """Raise a descriptive UnsupportedPatternError for an unconvertible node.

    Checks run in this order:

    1. A constant whose value has no IR literal form is rejected.
    2. Known unsupported syntax (f-strings, lambdas, comprehensions, generator
       expressions, the walrus operator, yield) gets a dedicated message and
       recommendation.
    3. Any other ``ast.expr`` gets a generic "unsupported expression" error.

    Nodes that are not ``ast.expr`` instances pass through without raising.
    """
    line = getattr(expr, "lineno", None)
    col = getattr(expr, "col_offset", None)

    def _fail(message: str, recommendation_key: str) -> NoReturn:
        # Look up the recommendation only when actually raising.
        raise UnsupportedPatternError(
            message,
            RECOMMENDATIONS[recommendation_key],
            line=line,
            col=col,
        )

    if isinstance(expr, ast.Constant) and _constant_to_literal(expr.value) is None:
        _fail(
            f"Unsupported literal type '{type(expr.value).__name__}'",
            "unsupported_literal",
        )

    # (node type(s), error message, RECOMMENDATIONS key)
    known_unsupported = (
        (ast.JoinedStr, "F-strings are not supported", "fstring"),
        (ast.Lambda, "Lambda expressions are not supported", "lambda"),
        (
            ast.ListComp,
            "List comprehensions are not supported in this context",
            "list_comprehension",
        ),
        (
            ast.DictComp,
            "Dict comprehensions are not supported in this context",
            "dict_comprehension",
        ),
        (ast.SetComp, "Set comprehensions are not supported", "set_comprehension"),
        (ast.GeneratorExp, "Generator expressions are not supported", "generator"),
        (ast.NamedExpr, "The walrus operator (:=) is not supported", "walrus"),
        (
            (ast.Yield, ast.YieldFrom),
            "Yield expressions are not supported",
            "yield_statement",
        ),
    )
    for node_types, message, recommendation_key in known_unsupported:
        if isinstance(expr, node_types):
            _fail(message, recommendation_key)

    if isinstance(expr, ast.expr):
        _fail(
            f"Unsupported expression type '{type(expr).__name__}'",
            "unsupported_expression",
        )
3829
+
3830
+
3831
def _format_subscript_target(target: ast.Subscript) -> Optional[str]:
    """Render a ``name[index]`` string for tracking a subscript target.

    Returns None when the subscripted object is not a bare ``ast.Name``
    (e.g. ``a.b[0]``), or when the index expression cannot be unparsed.
    """
    base = target.value
    if isinstance(base, ast.Name):
        try:
            rendered_index = ast.unparse(target.slice)
        except Exception:
            rendered_index = None
        if rendered_index is not None:
            return f"{base.id}[{rendered_index}]"
    return None
3843
+
3844
+
3845
+ def _constant_to_literal(value: Any) -> Optional[ir.Literal]:
3846
+ """Convert a Python constant to IR Literal."""
3847
+ literal = ir.Literal()
3848
+ if value is None:
3849
+ literal.is_none = True
3850
+ elif isinstance(value, bool):
3851
+ literal.bool_value = value
3852
+ elif isinstance(value, int):
3853
+ literal.int_value = value
3854
+ elif isinstance(value, float):
3855
+ literal.float_value = value
3856
+ elif isinstance(value, str):
3857
+ literal.string_value = value
3858
+ else:
3859
+ return None
3860
+ return literal
3861
+
3862
+
3863
+ def _bin_op_to_ir(op: ast.operator) -> Optional[ir.BinaryOperator]:
3864
+ """Convert Python binary operator to IR BinaryOperator."""
3865
+ mapping = {
3866
+ ast.Add: ir.BinaryOperator.BINARY_OP_ADD,
3867
+ ast.Sub: ir.BinaryOperator.BINARY_OP_SUB,
3868
+ ast.Mult: ir.BinaryOperator.BINARY_OP_MUL,
3869
+ ast.Div: ir.BinaryOperator.BINARY_OP_DIV,
3870
+ ast.FloorDiv: ir.BinaryOperator.BINARY_OP_FLOOR_DIV,
3871
+ ast.Mod: ir.BinaryOperator.BINARY_OP_MOD,
3872
+ }
3873
+ return mapping.get(type(op))
3874
+
3875
+
3876
+ def _unary_op_to_ir(op: ast.unaryop) -> Optional[ir.UnaryOperator]:
3877
+ """Convert Python unary operator to IR UnaryOperator."""
3878
+ mapping = {
3879
+ ast.USub: ir.UnaryOperator.UNARY_OP_NEG,
3880
+ ast.Not: ir.UnaryOperator.UNARY_OP_NOT,
3881
+ }
3882
+ return mapping.get(type(op))
3883
+
3884
+
3885
+ def _cmp_op_to_ir(op: ast.cmpop) -> Optional[ir.BinaryOperator]:
3886
+ """Convert Python comparison operator to IR BinaryOperator."""
3887
+ mapping = {
3888
+ ast.Eq: ir.BinaryOperator.BINARY_OP_EQ,
3889
+ ast.NotEq: ir.BinaryOperator.BINARY_OP_NE,
3890
+ ast.Lt: ir.BinaryOperator.BINARY_OP_LT,
3891
+ ast.LtE: ir.BinaryOperator.BINARY_OP_LE,
3892
+ ast.Gt: ir.BinaryOperator.BINARY_OP_GT,
3893
+ ast.GtE: ir.BinaryOperator.BINARY_OP_GE,
3894
+ ast.In: ir.BinaryOperator.BINARY_OP_IN,
3895
+ ast.NotIn: ir.BinaryOperator.BINARY_OP_NOT_IN,
3896
+ }
3897
+ return mapping.get(type(op))
3898
+
3899
+
3900
+ def _bool_op_to_ir(op: ast.boolop) -> Optional[ir.BinaryOperator]:
3901
+ """Convert Python boolean operator to IR BinaryOperator."""
3902
+ mapping = {
3903
+ ast.And: ir.BinaryOperator.BINARY_OP_AND,
3904
+ ast.Or: ir.BinaryOperator.BINARY_OP_OR,
3905
+ }
3906
+ return mapping.get(type(op))
3907
+
3908
+
3909
def _get_func_name(func: ast.expr) -> Optional[str]:
    """Return the (possibly dotted) callee name for a Call's ``func`` node.

    - ``foo`` -> ``"foo"``
    - ``a.b.c`` -> ``"a.b.c"``
    - ``self.run`` -> ``"run"`` (a leading ``self.`` is removed)

    If an attribute chain does not end in a bare name (e.g. ``f().g``), only
    the attribute segments are joined. Any other node kind yields None.
    """
    if isinstance(func, ast.Name):
        return func.id
    if not isinstance(func, ast.Attribute):
        return None

    # Walk from the outermost attribute inward, prepending so that the
    # segments end up in source order (root first).
    segments: List[str] = []
    node: ast.expr = func
    while isinstance(node, ast.Attribute):
        segments.insert(0, node.attr)
        node = node.value
    if isinstance(node, ast.Name):
        segments.insert(0, node.id)
    return ".".join(segments).removeprefix("self.")
3927
+
3928
+
3929
def _is_self_method_call(node: ast.Call) -> bool:
    """Tell whether ``node`` invokes a method directly on ``self``.

    Only the exact shape ``self.<name>(...)`` qualifies. Longer chains such as
    ``self.a.b(...)`` and calls on any other receiver return False.
    """
    callee = node.func
    if not isinstance(callee, ast.Attribute):
        return False
    receiver = callee.value
    return isinstance(receiver, ast.Name) and receiver.id == "self"
3937
+
3938
+
3939
def _global_function_for_call(
    func_name: str, node: ast.Call
) -> Optional[ir.GlobalFunction.ValueType]:
    """Look up the GlobalFunction tag for a call, if it names a known global.

    Calls of the form ``self.<name>(...)`` are never treated as globals, even
    when ``<name>`` matches a key in GLOBAL_FUNCTIONS. Unknown names yield None.
    """
    return None if _is_self_method_call(node) else GLOBAL_FUNCTIONS.get(func_name)