rappel 0.7.2__py3-none-macosx_15_0_arm64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only, and reflects the changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of rappel might be problematic. Click here for more details.
- proto/ast_pb2.py +121 -0
- proto/ast_pb2.pyi +1627 -0
- proto/ast_pb2_grpc.py +24 -0
- proto/ast_pb2_grpc.pyi +22 -0
- proto/messages_pb2.py +106 -0
- proto/messages_pb2.pyi +1231 -0
- proto/messages_pb2_grpc.py +406 -0
- proto/messages_pb2_grpc.pyi +380 -0
- rappel/__init__.py +61 -0
- rappel/actions.py +108 -0
- rappel/bin/boot-rappel-singleton +0 -0
- rappel/bin/rappel-bridge +0 -0
- rappel/bin/start-workers +0 -0
- rappel/bridge.py +228 -0
- rappel/dependencies.py +149 -0
- rappel/exceptions.py +18 -0
- rappel/formatter.py +110 -0
- rappel/ir_builder.py +3664 -0
- rappel/logger.py +39 -0
- rappel/registry.py +111 -0
- rappel/schedule.py +362 -0
- rappel/serialization.py +273 -0
- rappel/worker.py +191 -0
- rappel/workflow.py +251 -0
- rappel/workflow_runtime.py +287 -0
- rappel-0.7.2.data/scripts/boot-rappel-singleton +0 -0
- rappel-0.7.2.data/scripts/rappel-bridge +0 -0
- rappel-0.7.2.data/scripts/start-workers +0 -0
- rappel-0.7.2.dist-info/METADATA +308 -0
- rappel-0.7.2.dist-info/RECORD +32 -0
- rappel-0.7.2.dist-info/WHEEL +4 -0
- rappel-0.7.2.dist-info/entry_points.txt +2 -0
rappel/ir_builder.py
ADDED
|
@@ -0,0 +1,3664 @@
|
|
|
1
|
+
"""
|
|
2
|
+
IR Builder - Converts Python workflow AST to Rappel IR (ast.proto).
|
|
3
|
+
|
|
4
|
+
This module parses Python workflow classes and produces the IR representation
|
|
5
|
+
that can be sent to the Rust runtime for execution.
|
|
6
|
+
|
|
7
|
+
The IR builder performs deep transformations to convert Python patterns into
|
|
8
|
+
valid Rappel IR structures. Control flow bodies (try, for, if branches) are
|
|
9
|
+
represented as full blocks, so they can contain arbitrary statements and
|
|
10
|
+
`return` behaves consistently with Python (returns from the entry function end
|
|
11
|
+
the workflow).
|
|
12
|
+
|
|
13
|
+
Validation:
|
|
14
|
+
The IR builder proactively detects unsupported Python patterns and raises
|
|
15
|
+
UnsupportedPatternError with clear recommendations for how to rewrite the code.
|
|
16
|
+
"""
|
|
17
|
+
|
|
18
|
+
from __future__ import annotations

import ast
import copy
import inspect
import textwrap
from dataclasses import dataclass, field
from enum import EnumMeta
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Mapping, NoReturn, Optional, Set, Union

from proto import ast_pb2 as ir
from rappel.registry import registry
|
|
30
|
+
|
|
31
|
+
|
|
32
|
+
class UnsupportedPatternError(Exception):
    """Raised when the IR builder hits a Python pattern it cannot translate.

    Carries a human-readable recommendation describing how to rewrite the
    offending code using supported constructs, plus an optional source
    location (filename / line / column).
    """

    def __init__(
        self,
        message: str,
        recommendation: str,
        line: Optional[int] = None,
        col: Optional[int] = None,
        filename: Optional[str] = None,
    ):
        self.message = message
        self.recommendation = recommendation
        self.line = line
        self.col = col
        self.filename = filename

        # Assemble an optional "(file, line N, col M)" suffix.  Note the
        # deliberate truthiness checks: falsy filename/line are omitted,
        # while col participates whenever it is not None (col 0 is valid).
        parts: List[str] = []
        if filename:
            parts.append(filename)
        if line:
            parts.append(f"line {line}")
        if col is not None:
            parts.append(f"col {col}")
        suffix = " ({})".format(", ".join(parts)) if parts else ""
        super().__init__(f"{message}{suffix}\n\nRecommendation: {recommendation}")
|
|
63
|
+
|
|
64
|
+
|
|
65
|
+
# Recommendations for common unsupported patterns
# Keyed by an internal pattern identifier; each value is the user-facing
# guidance text attached to the UnsupportedPatternError raised when the IR
# builder encounters that pattern.  These strings are part of the runtime
# behavior (they are shown to users verbatim) — do not edit casually.
RECOMMENDATIONS = {
    "constructor_return": (
        "Returning a class constructor (like MyModel(...)) directly is not supported.\n"
        "The workflow IR cannot serialize arbitrary object instantiation.\n\n"
        "Use an @action to create the object:\n\n"
        "    @action\n"
        "    async def build_result(items: list, count: int) -> MyResult:\n"
        "        return MyResult(items=items, count=count)\n\n"
        "    # In workflow:\n"
        "    return await build_result(items, count)"
    ),
    "constructor_assignment": (
        "Assigning a class constructor result (like x = MyClass(...)) is not supported.\n"
        "The workflow IR cannot serialize arbitrary object instantiation.\n\n"
        "Use an @action to create the object:\n\n"
        "    @action\n"
        "    async def create_config(value: int) -> Config:\n"
        "        return Config(value=value)\n\n"
        "    # In workflow:\n"
        "    config = await create_config(value)"
    ),
    "non_action_call": (
        "Calling a function that is not decorated with @action is not supported.\n"
        "Only @action decorated functions can be awaited in workflow code.\n\n"
        "Add the @action decorator to your function:\n\n"
        "    @action\n"
        "    async def my_function(x: int) -> int:\n"
        "        return x * 2"
    ),
    "sync_function_call": (
        "Calling a synchronous function directly in workflow code is not supported.\n"
        "All computation must happen inside @action decorated async functions.\n\n"
        "Wrap your logic in an @action:\n\n"
        "    @action\n"
        "    async def compute(x: int) -> int:\n"
        "        return some_sync_function(x)"
    ),
    "method_call_non_self": (
        "Calling methods on objects other than 'self' is not supported in workflow code.\n"
        "Use an @action to perform method calls:\n\n"
        "    @action\n"
        "    async def call_method(obj: MyClass) -> Result:\n"
        "        return obj.some_method()"
    ),
    "builtin_call": (
        "Calling built-in functions like len(), str(), int() directly is not supported.\n"
        "Use an @action to perform these operations:\n\n"
        "    @action\n"
        "    async def get_length(items: list) -> int:\n"
        "        return len(items)"
    ),
    "fstring": (
        "F-strings are not supported in workflow code because they require "
        "runtime string interpolation.\n"
        "Use an @action to perform string formatting:\n\n"
        "    @action\n"
        "    async def format_message(value: int) -> str:\n"
        "        return f'Result: {value}'"
    ),
    "delete": (
        "The 'del' statement is not supported in workflow code.\n"
        "Use an @action to perform mutations:\n\n"
        "    @action\n"
        "    async def remove_key(data: dict, key: str) -> dict:\n"
        "        del data[key]\n"
        "        return data"
    ),
    "while_loop": (
        "While loops are not supported in workflow code because they can run "
        "indefinitely.\n"
        "Use a for loop with a fixed range, or restructure as recursive workflow calls."
    ),
    "with_statement": (
        "Context managers (with statements) are not supported in workflow code.\n"
        "Use an @action to handle resource management:\n\n"
        "    @action\n"
        "    async def read_file(path: str) -> str:\n"
        "        with open(path) as f:\n"
        "            return f.read()"
    ),
    "raise_statement": (
        "The 'raise' statement is not supported directly in workflow code.\n"
        "Use an @action that raises exceptions, or return error values."
    ),
    "assert_statement": (
        "Assert statements are not supported in workflow code.\n"
        "Use an @action for validation, or use if statements with explicit error handling."
    ),
    "lambda": (
        "Lambda expressions are not supported in workflow code.\n"
        "Use an @action to define the function logic."
    ),
    "list_comprehension": (
        "List comprehensions are only supported when assigned directly to a variable "
        "or inside asyncio.gather(*[...]).\n"
        "For other cases, use a for loop or an @action."
    ),
    "dict_comprehension": (
        "Dict comprehensions are only supported when assigned directly to a variable.\n"
        "For other cases, use a for loop or an @action:\n\n"
        "    @action\n"
        "    async def build_dict(items: list) -> dict:\n"
        "        return {k: v for k, v in items}"
    ),
    "set_comprehension": (
        "Set comprehensions are not supported in workflow code.\nUse an @action to build sets."
    ),
    "generator": (
        "Generator expressions are not supported in workflow code.\n"
        "Use a list or an @action instead."
    ),
    "walrus": (
        "The walrus operator (:=) is not supported in workflow code.\n"
        "Use separate assignment statements instead."
    ),
    "match": (
        "Match statements are not supported in workflow code.\nUse if/elif/else chains instead."
    ),
    "gather_variable_spread": (
        "Spreading a variable in asyncio.gather() is not supported because it requires "
        "data flow analysis to determine the contents.\n"
        "Use a list comprehension directly in gather:\n\n"
        "    # Instead of:\n"
        "    tasks = []\n"
        "    for i in range(count):\n"
        "        tasks.append(process(value=i))\n"
        "    results = await asyncio.gather(*tasks)\n\n"
        "    # Use:\n"
        "    results = await asyncio.gather(*[process(value=i) for i in range(count)])"
    ),
    "for_loop_append_pattern": (
        "Building a task list in a for loop then spreading in asyncio.gather() is not "
        "supported.\n"
        "Use a list comprehension directly in gather:\n\n"
        "    # Instead of:\n"
        "    tasks = []\n"
        "    for i in range(count):\n"
        "        tasks.append(process(value=i))\n"
        "    results = await asyncio.gather(*tasks)\n\n"
        "    # Use:\n"
        "    results = await asyncio.gather(*[process(value=i) for i in range(count)])"
    ),
    "global_statement": (
        "Global statements are not supported in workflow code.\n"
        "Use workflow state or pass values explicitly."
    ),
    "nonlocal_statement": (
        "Nonlocal statements are not supported in workflow code.\n"
        "Use explicit parameter passing instead."
    ),
    "import_statement": (
        "Import statements inside workflow run() are not supported.\n"
        "Place imports at the module level."
    ),
    "class_def": (
        "Class definitions inside workflow run() are not supported.\n"
        "Define classes at the module level."
    ),
    "function_def": (
        "Nested function definitions inside workflow run() are not supported.\n"
        "Define functions at the module level or use @action."
    ),
    "yield_statement": (
        "Yield statements are not supported in workflow code.\n"
        "Workflows must return a complete result, not generate values incrementally."
    ),
    "continue_statement": (
        "Continue statements are not supported in workflow code.\n"
        "Restructure your loop using if/else to skip iterations."
    ),
    "unsupported_statement": (
        "This statement type is not supported in workflow code.\n"
        "Move the logic into an @action or rewrite using supported statements."
    ),
    "unsupported_expression": (
        "This expression type is not supported in workflow code.\n"
        "Move the logic into an @action or rewrite using supported expressions."
    ),
    "unsupported_literal": (
        "This literal type is not supported in workflow code.\n"
        "Convert the value to a supported literal type inside an @action."
    ),
}
|
|
249
|
+
|
|
250
|
+
# Python builtins the runtime understands natively: calls to these names are
# mapped to the corresponding ir.GlobalFunction enum value instead of being
# rejected as sync-function calls.
GLOBAL_FUNCTIONS = {
    "enumerate": ir.GlobalFunction.GLOBAL_FUNCTION_ENUMERATE,
    "len": ir.GlobalFunction.GLOBAL_FUNCTION_LEN,
    "range": ir.GlobalFunction.GLOBAL_FUNCTION_RANGE,
}
# Synchronous calls permitted in workflow code — exactly the runtime-provided
# globals above.
ALLOWED_SYNC_FUNCTIONS = set(GLOBAL_FUNCTIONS)

# Module-level scratch state: the action names visible to the function
# currently being lowered.  Populated per-function inside build_workflow_ir()
# (see add_function_def) and reset to empty when the build completes.
_CURRENT_ACTION_NAMES: set[str] = set()
|
|
258
|
+
|
|
259
|
+
|
|
260
|
+
if TYPE_CHECKING:
|
|
261
|
+
from .workflow import Workflow
|
|
262
|
+
|
|
263
|
+
|
|
264
|
+
@dataclass
class ActionDefinition:
    """Definition of an action function.

    Produced by _discover_action_names() from @action-decorated functions and
    registry entries; consumed by the IRBuilder when lowering awaited calls.
    """

    # Registered action name (taken from __rappel_action_name__ or the
    # registry entry).
    action_name: str
    # Module the action was registered under; falls back to the module the
    # function was discovered in.
    module_name: Optional[str]
    # Python call signature, used to validate and fill in call arguments.
    signature: inspect.Signature
|
|
271
|
+
|
|
272
|
+
|
|
273
|
+
@dataclass
class FunctionParameter:
    """A function parameter with optional default value."""

    # Parameter name as declared in the function signature.
    name: str
    # True when the parameter declares a default in its signature.
    has_default: bool
    default_value: Any = None  # The actual Python value if has_default is True
|
|
280
|
+
|
|
281
|
+
|
|
282
|
+
@dataclass
class FunctionSignatureInfo:
    """Signature info for a workflow function, used to fill in default arguments."""

    # Declared parameters in order, excluding ``self`` (stripped during
    # collection in build_workflow_ir).
    parameters: List[FunctionParameter]
|
|
287
|
+
|
|
288
|
+
|
|
289
|
+
@dataclass
class ModuleContext:
    """Cached IRBuilder context derived from a module.

    Built once per module by build_workflow_ir()'s get_module_context() so
    repeated lowering of functions from the same module reuses discovery
    results.
    """

    # @action functions visible in the module, keyed by attribute name.
    action_defs: Dict[str, ActionDefinition]
    # Names imported into the module (from _discover_module_imports).
    imported_names: Dict[str, "ImportedName"]
    # Async functions defined in the module that are NOT actions.
    module_functions: Set[str]
    # Simple Pydantic models / dataclasses usable in workflow code.
    model_defs: Dict[str, "ModelDefinition"]
|
|
297
|
+
|
|
298
|
+
|
|
299
|
+
@dataclass
class TransformContext:
    """Context for IR transformations.

    Accumulates state shared across the lowering of all functions in a single
    workflow build: implicit (generated) functions, and the signatures of
    workflow methods used to fill in default arguments at call sites.
    """

    # Counter for generating unique implicit function names.
    implicit_fn_counter: int = 0
    # Implicit functions generated during transformation.  default_factory
    # replaces the previous ``= None  # type: ignore`` workaround so each
    # instance gets its own fresh list.
    implicit_functions: List[ir.FunctionDef] = field(default_factory=list)
    # Function signatures for workflow methods (used to fill in default
    # arguments when lowering calls).
    function_signatures: Dict[str, FunctionSignatureInfo] = field(default_factory=dict)

    def __post_init__(self) -> None:
        # Backward compatibility: callers that explicitly pass None still get
        # fresh containers, matching the pre-default_factory behavior.
        if self.implicit_functions is None:
            self.implicit_functions = []
        if self.function_signatures is None:
            self.function_signatures = {}

    def next_implicit_fn_name(self, prefix: str = "implicit") -> str:
        """Generate a unique implicit function name (e.g. ``__implicit_1__``)."""
        self.implicit_fn_counter += 1
        return f"__{prefix}_{self.implicit_fn_counter}__"
|
|
320
|
+
|
|
321
|
+
|
|
322
|
+
def build_workflow_ir(workflow_cls: type["Workflow"]) -> ir.Program:
    """Build an IR Program from a workflow class.

    Lowers the workflow's ``run`` implementation plus every helper method
    reachable through ``self.<method>()`` calls into IR function definitions,
    collecting them into a single Program.

    Args:
        workflow_cls: The workflow class to convert.

    Returns:
        An IR Program proto message.

    Raises:
        ValueError: If the class has no run implementation or a module cannot
            be located.
        UnsupportedPatternError: Re-raised from the IRBuilder with locations
            rebased onto absolute file coordinates.
    """
    # Prefer the stashed original implementation (__workflow_run_impl__,
    # presumably set by the workflow decorator — confirm in workflow.py),
    # falling back to the class's own ``run``.
    original_run = getattr(workflow_cls, "__workflow_run_impl__", None)
    if original_run is None:
        original_run = workflow_cls.__dict__.get("run")
    if original_run is None:
        raise ValueError(f"workflow {workflow_cls!r} missing run implementation")

    module = inspect.getmodule(original_run)
    if module is None:
        raise ValueError(f"unable to locate module for workflow {workflow_cls!r}")

    # Cache of per-module discovery results, so workflows whose helpers span
    # several modules only scan each module once.
    module_contexts: Dict[str, ModuleContext] = {}

    def get_module_context(target_module: Any) -> ModuleContext:
        # Lazily build and memoize the discovery context for a module.
        module_name = target_module.__name__
        if module_name not in module_contexts:
            module_contexts[module_name] = ModuleContext(
                action_defs=_discover_action_names(target_module),
                imported_names=_discover_module_imports(target_module),
                module_functions=_discover_module_functions(target_module),
                model_defs=_discover_model_definitions(target_module),
            )
        return module_contexts[module_name]

    # Build the IR with transformation context
    ctx = TransformContext()
    program = ir.Program()
    function_defs: Dict[str, ir.FunctionDef] = {}

    # Extract instance attributes from __init__ for policy resolution
    instance_attrs = _extract_instance_attrs(workflow_cls)

    def parse_function(fn: Any) -> tuple[ast.AST, Optional[str], int]:
        # Parse a function's (dedented) source, also returning its filename
        # and first line number so errors can be mapped back to the file.
        source_lines, start_line = inspect.getsourcelines(fn)
        function_source = textwrap.dedent("".join(source_lines))
        filename = inspect.getsourcefile(fn)
        if filename is None:
            filename = inspect.getfile(fn)
        return ast.parse(function_source, filename=filename or "<unknown>"), filename, start_line

    def _with_source_location(
        err: UnsupportedPatternError,
        filename: Optional[str],
        start_line: int,
    ) -> UnsupportedPatternError:
        # Rebase a function-relative error location onto absolute file
        # coordinates: lines shift by the function's start line, columns
        # convert from 0-based (ast) to 1-based (human-readable).
        line = err.line
        col = err.col
        if line is not None:
            line = start_line + line - 1
        if col is not None:
            col = col + 1
        return UnsupportedPatternError(
            err.message,
            err.recommendation,
            line=line,
            col=col,
            filename=filename,
        )

    def add_function_def(
        fn: Any,
        fn_tree: ast.AST,
        filename: Optional[str],
        start_line: int,
        override_name: Optional[str] = None,
    ) -> None:
        # Lower a single function AST into an ir.FunctionDef and record it
        # in function_defs under its (possibly overridden) name.
        global _CURRENT_ACTION_NAMES
        fn_module = inspect.getmodule(fn)
        if fn_module is None:
            raise ValueError(f"unable to locate module for function {fn!r}")

        ctx_data = get_module_context(fn_module)
        # Expose the module's action names to the builder via module-level
        # scratch state for the duration of this function's lowering.
        _CURRENT_ACTION_NAMES = set(ctx_data.action_defs.keys())
        builder = IRBuilder(
            ctx_data.action_defs,
            ctx,
            ctx_data.imported_names,
            ctx_data.module_functions,
            ctx_data.model_defs,
            fn_module.__dict__,
            instance_attrs,
        )
        try:
            builder.visit(fn_tree)
        except UnsupportedPatternError as err:
            raise _with_source_location(err, filename, start_line) from err
        if builder.function_def:
            if override_name:
                builder.function_def.name = override_name
            function_defs[builder.function_def.name] = builder.function_def

    # Discover all reachable helper methods first so we can pre-collect their signatures.
    # This is needed because when we process a function that calls another function,
    # we need to know the callee's signature to fill in default arguments.
    run_tree, run_filename, run_start_line = parse_function(original_run)

    # Collect all reachable methods and their trees
    methods_to_process: List[tuple[Any, ast.AST, Optional[str], int, Optional[str]]] = [
        (original_run, run_tree, run_filename, run_start_line, "main")
    ]
    # Breadth of reachability: any self.<name>() call pulls <name> in.
    pending = list(_collect_self_method_calls(run_tree))
    visited: Set[str] = set()
    skip_methods = {"run_action"}

    while pending:
        method_name = pending.pop()
        # Skip already-seen names, the entry point itself, and framework hooks.
        if method_name in visited or method_name == "run" or method_name in skip_methods:
            continue
        visited.add(method_name)

        method = _find_workflow_method(workflow_cls, method_name)
        if method is None:
            continue

        method_tree, method_filename, method_start_line = parse_function(method)
        methods_to_process.append((method, method_tree, method_filename, method_start_line, None))
        # Transitively follow self-method calls made by the helper.
        pending.extend(_collect_self_method_calls(method_tree))

    # Pre-collect signatures for all methods before processing IR.
    # This ensures that when we encounter a function call, we can fill in default args.
    for fn, _fn_tree, _filename, _start_line, override_name in methods_to_process:
        fn_name = override_name if override_name else fn.__name__
        sig = inspect.signature(fn)
        params: List[FunctionParameter] = []
        for param_name, param in sig.parameters.items():
            if param_name == "self":
                continue
            has_default = param.default is not inspect.Parameter.empty
            default_value = param.default if has_default else None
            params.append(
                FunctionParameter(
                    name=param_name,
                    has_default=has_default,
                    default_value=default_value,
                )
            )
        ctx.function_signatures[fn_name] = FunctionSignatureInfo(parameters=params)

    # Now process all functions with signatures already available
    for fn, fn_tree, filename, start_line, override_name in methods_to_process:
        add_function_def(fn, fn_tree, filename, start_line, override_name)

    # Add implicit functions first (they may be called by the main function)
    for implicit_fn in ctx.implicit_functions:
        program.functions.append(implicit_fn)

    # Add all function definitions (run + reachable helper methods)
    for fn_def in function_defs.values():
        program.functions.append(fn_def)

    # Clear the module-level scratch state now that the build is complete.
    global _CURRENT_ACTION_NAMES
    _CURRENT_ACTION_NAMES = set()

    return program
|
|
484
|
+
|
|
485
|
+
|
|
486
|
+
def _collect_self_method_calls(tree: ast.AST) -> Set[str]:
|
|
487
|
+
"""Collect self.method(...) call names from a parsed function AST."""
|
|
488
|
+
calls: Set[str] = set()
|
|
489
|
+
for node in ast.walk(tree):
|
|
490
|
+
if not isinstance(node, ast.Call):
|
|
491
|
+
continue
|
|
492
|
+
func = node.func
|
|
493
|
+
if isinstance(func, ast.Attribute) and isinstance(func.value, ast.Name):
|
|
494
|
+
if func.value.id == "self":
|
|
495
|
+
calls.add(func.attr)
|
|
496
|
+
return calls
|
|
497
|
+
|
|
498
|
+
|
|
499
|
+
def _find_workflow_method(workflow_cls: type["Workflow"], name: str) -> Optional[Any]:
|
|
500
|
+
"""Find a workflow method by name across the class MRO."""
|
|
501
|
+
for base in workflow_cls.__mro__:
|
|
502
|
+
if name not in base.__dict__:
|
|
503
|
+
continue
|
|
504
|
+
value = base.__dict__[name]
|
|
505
|
+
if isinstance(value, staticmethod) or isinstance(value, classmethod):
|
|
506
|
+
return value.__func__
|
|
507
|
+
if inspect.isfunction(value):
|
|
508
|
+
return value
|
|
509
|
+
return None
|
|
510
|
+
|
|
511
|
+
|
|
512
|
+
def _extract_instance_attrs(workflow_cls: type["Workflow"]) -> Dict[str, ast.expr]:
    """Extract ``self.attr = value`` assignments from the workflow's __init__.

    Parses the __init__ source looking for assignments such as::

        self.retry_policy = RetryPolicy(attempts=3)
        self.timeout = 30

    Returns a dict mapping attribute names to their AST value nodes; an empty
    dict when __init__ is missing or its source cannot be parsed.
    """
    init_fn = _find_workflow_method(workflow_cls, "__init__")
    if init_fn is None:
        return {}

    try:
        lines, _ = inspect.getsourcelines(init_fn)
        parsed = ast.parse(textwrap.dedent("".join(lines)))
    except (OSError, TypeError, SyntaxError):
        # Source unavailable (built-ins, C extensions, REPL) or unparsable.
        return {}

    discovered: Dict[str, ast.expr] = {}

    # Scan the __init__ body for single-target ``self.<attr> = <expr>``.
    for stmt in ast.walk(parsed):
        if not isinstance(stmt, ast.Assign) or len(stmt.targets) != 1:
            continue
        lhs = stmt.targets[0]
        is_self_attr = (
            isinstance(lhs, ast.Attribute)
            and isinstance(lhs.value, ast.Name)
            and lhs.value.id == "self"
        )
        if is_self_attr:
            discovered[lhs.attr] = stmt.value

    return discovered
|
|
551
|
+
|
|
552
|
+
|
|
553
|
+
def _discover_action_names(module: Any) -> Dict[str, ActionDefinition]:
    """Discover all @action decorated functions in a module.

    Scans the module's attributes for callables tagged with
    ``__rappel_action_name__``, then supplements with registry entries
    registered under this module that were not already found.
    """
    discovered: Dict[str, ActionDefinition] = {}

    for attr_name in dir(module):
        candidate = getattr(module, attr_name)
        tagged_name = getattr(candidate, "__rappel_action_name__", None)
        tagged_module = getattr(candidate, "__rappel_action_module__", None)
        if not (callable(candidate) and tagged_name):
            continue
        discovered[attr_name] = ActionDefinition(
            action_name=tagged_name,
            module_name=tagged_module or module.__name__,
            signature=inspect.signature(candidate),
        )

    # Registry entries for this module fill in anything the attribute scan
    # missed; names already discovered win.
    for entry in registry.entries():
        if entry.module != module.__name__:
            continue
        func_name = entry.func.__name__
        if func_name in discovered:
            continue
        discovered[func_name] = ActionDefinition(
            action_name=entry.name,
            module_name=entry.module,
            signature=inspect.signature(entry.func),
        )

    return discovered
|
|
580
|
+
|
|
581
|
+
|
|
582
|
+
def _discover_module_functions(module: Any) -> Set[str]:
|
|
583
|
+
"""Discover all async function names defined in a module.
|
|
584
|
+
|
|
585
|
+
This is used to detect when users await functions in the same module
|
|
586
|
+
that are NOT decorated with @action.
|
|
587
|
+
"""
|
|
588
|
+
function_names: Set[str] = set()
|
|
589
|
+
for attr_name in dir(module):
|
|
590
|
+
try:
|
|
591
|
+
attr = getattr(module, attr_name)
|
|
592
|
+
except AttributeError:
|
|
593
|
+
continue
|
|
594
|
+
|
|
595
|
+
# Only include functions defined in THIS module (not imported)
|
|
596
|
+
if not callable(attr):
|
|
597
|
+
continue
|
|
598
|
+
if not inspect.iscoroutinefunction(attr):
|
|
599
|
+
continue
|
|
600
|
+
|
|
601
|
+
# Check if the function is defined in this module
|
|
602
|
+
func_module = getattr(attr, "__module__", None)
|
|
603
|
+
if func_module == module.__name__:
|
|
604
|
+
function_names.add(attr_name)
|
|
605
|
+
|
|
606
|
+
return function_names
|
|
607
|
+
|
|
608
|
+
|
|
609
|
+
@dataclass
class ImportedName:
    """Tracks an imported name and its source module.

    Produced by _discover_module_imports() and handed to the IRBuilder (via
    ModuleContext.imported_names) so calls through imported names can be
    resolved during lowering.
    """

    local_name: str  # Name used in code (e.g., "sleep")
    module: str  # Source module (e.g., "asyncio")
    original_name: str  # Original name in source module (e.g., "sleep")
|
|
616
|
+
|
|
617
|
+
|
|
618
|
+
@dataclass
class ModelFieldDefinition:
    """Definition of a field in a Pydantic model or dataclass."""

    # Field name as declared on the model/dataclass.
    name: str
    # True only for a plain default value; fields using default_factory are
    # reported as having no default (factories cannot be serialized).
    has_default: bool
    default_value: Any = None  # Only set if has_default is True
|
|
625
|
+
|
|
626
|
+
|
|
627
|
+
@dataclass
class ModelDefinition:
    """Definition of a Pydantic model or dataclass that can be used in workflows.

    These are data classes that can be instantiated in workflow code and will
    be converted to dictionary expressions in the IR.  Only "simple" classes
    (no validators, no custom __init__/__post_init__) are eligible — see
    _discover_model_definitions.
    """

    # Class name as it appears in the defining module.
    class_name: str
    # Declared fields, keyed by field name.
    fields: Dict[str, ModelFieldDefinition]
    is_pydantic: bool  # True for Pydantic models, False for dataclasses
|
|
638
|
+
|
|
639
|
+
|
|
640
|
+
def _is_simple_pydantic_model(cls: type) -> bool:
|
|
641
|
+
"""Check if a class is a simple Pydantic model without custom validators.
|
|
642
|
+
|
|
643
|
+
A simple Pydantic model:
|
|
644
|
+
- Inherits from pydantic.BaseModel
|
|
645
|
+
- Has no field_validator or model_validator decorators
|
|
646
|
+
- Has no custom __init__ method
|
|
647
|
+
|
|
648
|
+
Returns False if pydantic is not installed or cls is not a Pydantic model.
|
|
649
|
+
"""
|
|
650
|
+
try:
|
|
651
|
+
from pydantic import BaseModel
|
|
652
|
+
except ImportError:
|
|
653
|
+
return False
|
|
654
|
+
|
|
655
|
+
if not isinstance(cls, type) or not issubclass(cls, BaseModel):
|
|
656
|
+
return False
|
|
657
|
+
|
|
658
|
+
# Check for validators - Pydantic v2 uses __pydantic_decorators__
|
|
659
|
+
decorators = getattr(cls, "__pydantic_decorators__", None)
|
|
660
|
+
if decorators is not None:
|
|
661
|
+
# Check for field validators
|
|
662
|
+
if hasattr(decorators, "field_validators") and decorators.field_validators:
|
|
663
|
+
return False
|
|
664
|
+
# Check for model validators
|
|
665
|
+
if hasattr(decorators, "model_validators") and decorators.model_validators:
|
|
666
|
+
return False
|
|
667
|
+
|
|
668
|
+
# Check for custom __init__ (not the one from BaseModel)
|
|
669
|
+
if "__init__" in cls.__dict__:
|
|
670
|
+
return False
|
|
671
|
+
|
|
672
|
+
return True
|
|
673
|
+
|
|
674
|
+
|
|
675
|
+
def _is_simple_dataclass(cls: type) -> bool:
|
|
676
|
+
"""Check if a class is a simple dataclass without custom logic.
|
|
677
|
+
|
|
678
|
+
A simple dataclass:
|
|
679
|
+
- Is decorated with @dataclass
|
|
680
|
+
- Has no custom __init__ method (uses the generated one)
|
|
681
|
+
- Has no __post_init__ method
|
|
682
|
+
|
|
683
|
+
Returns False if cls is not a dataclass.
|
|
684
|
+
"""
|
|
685
|
+
import dataclasses
|
|
686
|
+
|
|
687
|
+
if not dataclasses.is_dataclass(cls):
|
|
688
|
+
return False
|
|
689
|
+
|
|
690
|
+
# Check for __post_init__ which could have custom logic
|
|
691
|
+
if hasattr(cls, "__post_init__") and "__post_init__" in cls.__dict__:
|
|
692
|
+
return False
|
|
693
|
+
|
|
694
|
+
# Dataclasses generate __init__, so check if there's a custom one
|
|
695
|
+
# that overrides it (unlikely but possible)
|
|
696
|
+
# The dataclass decorator sets __init__ so we can't easily detect override
|
|
697
|
+
# We'll trust that dataclasses without __post_init__ are simple
|
|
698
|
+
|
|
699
|
+
return True
|
|
700
|
+
|
|
701
|
+
|
|
702
|
+
def _get_pydantic_model_fields(cls: type) -> Dict[str, ModelFieldDefinition]:
    """Extract field definitions from a Pydantic (v2) model.

    Returns a mapping of field name -> ModelFieldDefinition. Fields whose
    default comes from a ``default_factory`` are reported as having no
    default, because factory callables cannot be serialized into the IR.

    Returns an empty dict when pydantic is unavailable or *cls* is not a
    BaseModel subclass.
    """
    try:
        from pydantic import BaseModel
        from pydantic.fields import FieldInfo

        # PydanticUndefined is the v2 sentinel for "no default supplied".
        # Hoisted out of the per-field loop (it was previously re-imported on
        # every field) and into the guarded block so a missing pydantic_core
        # degrades gracefully instead of raising mid-iteration.
        from pydantic_core import PydanticUndefined
    except ImportError:
        return {}

    if not issubclass(cls, BaseModel):
        return {}

    fields: Dict[str, ModelFieldDefinition] = {}

    # Pydantic v2 exposes field metadata via the ``model_fields`` mapping;
    # on v1 this attribute does not exist and the loop body never runs.
    model_fields = getattr(cls, "model_fields", {})
    for field_name, field_info in model_fields.items():
        has_default = False
        default_value = None

        if isinstance(field_info, FieldInfo):
            if field_info.default is not PydanticUndefined:
                has_default = True
                default_value = field_info.default
            elif field_info.default_factory is not None:
                # Factory defaults are not serializable; treat as no default.
                has_default = False

        fields[field_name] = ModelFieldDefinition(
            name=field_name,
            has_default=has_default,
            default_value=default_value,
        )

    return fields
|
|
740
|
+
|
|
741
|
+
|
|
742
|
+
def _get_dataclass_fields(cls: type) -> Dict[str, ModelFieldDefinition]:
    """Extract field definitions from a dataclass.

    Fields whose default comes from a ``default_factory`` are reported as
    having no default, since factory callables cannot be serialized.
    Returns an empty dict when *cls* is not a dataclass.
    """
    import dataclasses

    if not dataclasses.is_dataclass(cls):
        return {}

    fields: Dict[str, ModelFieldDefinition] = {}
    for field in dataclasses.fields(cls):
        # MISSING marks "no plain default"; a default_factory (also checked
        # via MISSING) is deliberately treated the same as no default.
        has_default = field.default is not dataclasses.MISSING
        default_value = field.default if has_default else None
        fields[field.name] = ModelFieldDefinition(
            name=field.name,
            has_default=has_default,
            default_value=default_value,
        )

    return fields
|
|
768
|
+
|
|
769
|
+
|
|
770
|
+
def _discover_model_definitions(module: Any) -> Dict[str, ModelDefinition]:
    """Discover Pydantic models and dataclasses usable inside workflows.

    Only "simple" models (no custom validators, no __post_init__) are
    collected. Both locally defined and imported classes are considered,
    since workflow models may come from other modules.
    """
    models: Dict[str, ModelDefinition] = {}

    for attr_name in dir(module):
        try:
            candidate = getattr(module, attr_name)
        except AttributeError:
            continue

        if not isinstance(candidate, type):
            continue

        # Pydantic takes precedence; a class can never be both here.
        if _is_simple_pydantic_model(candidate):
            models[attr_name] = ModelDefinition(
                class_name=attr_name,
                fields=_get_pydantic_model_fields(candidate),
                is_pydantic=True,
            )
        elif _is_simple_dataclass(candidate):
            models[attr_name] = ModelDefinition(
                class_name=attr_name,
                fields=_get_dataclass_fields(candidate),
                is_pydantic=False,
            )

    return models
|
|
804
|
+
|
|
805
|
+
|
|
806
|
+
def _discover_module_imports(module: Any) -> Dict[str, ImportedName]:
    """Discover ``from X import Y`` bindings by parsing the module's source.

    Examples:
    - ``from asyncio import sleep``      -> {"sleep": ImportedName("sleep", "asyncio", "sleep")}
    - ``from asyncio import sleep as s`` -> {"s": ImportedName("s", "asyncio", "sleep")}

    Plain ``import X`` statements are intentionally ignored.
    """
    imported: Dict[str, ImportedName] = {}

    try:
        tree = ast.parse(inspect.getsource(module))
    except (OSError, TypeError):
        # Source unavailable (e.g. a built-in or extension module).
        return imported

    for node in ast.walk(tree):
        # Only `from <module> import ...` with a concrete module name counts;
        # relative `from . import x` has node.module == None and is skipped.
        if not (isinstance(node, ast.ImportFrom) and node.module):
            continue
        for alias in node.names:
            local_name = alias.asname or alias.name
            imported[local_name] = ImportedName(
                local_name=local_name,
                module=node.module,
                original_name=alias.name,
            )

    return imported
|
|
833
|
+
|
|
834
|
+
|
|
835
|
+
class IRBuilder(ast.NodeVisitor):
|
|
836
|
+
"""Builds IR from Python AST with deep transformations."""
|
|
837
|
+
|
|
838
|
+
def __init__(
|
|
839
|
+
self,
|
|
840
|
+
action_defs: Dict[str, ActionDefinition],
|
|
841
|
+
ctx: TransformContext,
|
|
842
|
+
imported_names: Optional[Dict[str, ImportedName]] = None,
|
|
843
|
+
module_functions: Optional[Set[str]] = None,
|
|
844
|
+
model_defs: Optional[Dict[str, ModelDefinition]] = None,
|
|
845
|
+
module_globals: Optional[Mapping[str, Any]] = None,
|
|
846
|
+
instance_attrs: Optional[Dict[str, ast.expr]] = None,
|
|
847
|
+
):
|
|
848
|
+
self._action_defs = action_defs
|
|
849
|
+
self._ctx = ctx
|
|
850
|
+
self._imported_names = imported_names or {}
|
|
851
|
+
self._module_functions = module_functions or set()
|
|
852
|
+
self._model_defs = model_defs or {}
|
|
853
|
+
self._module_globals = module_globals or {}
|
|
854
|
+
self._instance_attrs = instance_attrs or {}
|
|
855
|
+
self.function_def: Optional[ir.FunctionDef] = None
|
|
856
|
+
self._statements: List[ir.Statement] = []
|
|
857
|
+
|
|
858
|
+
    def visit_FunctionDef(self, node: ast.FunctionDef) -> Any:
        """Visit a function definition (the workflow's run method).

        Builds ``self.function_def`` from the Python function: collects the
        declared inputs, lowers every body statement to IR, and attaches the
        resulting block as the function body.
        """
        inputs = self._collect_function_inputs(node)

        # Create the function definition (outputs are not declared here).
        self.function_def = ir.FunctionDef(
            name=node.name,
            io=ir.IoDecl(inputs=inputs, outputs=[]),
            span=_make_span(node),
        )

        # Visit the body - _visit_statement now returns a list, because a
        # single Python statement may lower to several IR statements.
        self._statements = []
        for stmt in node.body:
            ir_stmts = self._visit_statement(stmt)
            self._statements.extend(ir_stmts)

        # Set the body
        self.function_def.body.CopyFrom(ir.Block(statements=self._statements))
|
|
877
|
+
|
|
878
|
+
def visit_AsyncFunctionDef(self, node: ast.AsyncFunctionDef) -> Any:
|
|
879
|
+
"""Visit an async function definition (the workflow's run method)."""
|
|
880
|
+
# Handle async the same way as sync for IR building
|
|
881
|
+
inputs = self._collect_function_inputs(node)
|
|
882
|
+
|
|
883
|
+
self.function_def = ir.FunctionDef(
|
|
884
|
+
name=node.name,
|
|
885
|
+
io=ir.IoDecl(inputs=inputs, outputs=[]),
|
|
886
|
+
span=_make_span(node),
|
|
887
|
+
)
|
|
888
|
+
|
|
889
|
+
self._statements = []
|
|
890
|
+
for stmt in node.body:
|
|
891
|
+
ir_stmts = self._visit_statement(stmt)
|
|
892
|
+
self._statements.extend(ir_stmts)
|
|
893
|
+
|
|
894
|
+
self.function_def.body.CopyFrom(ir.Block(statements=self._statements))
|
|
895
|
+
|
|
896
|
+
    def _visit_statement(self, node: ast.stmt) -> List[ir.Statement]:
        """Convert a Python statement to IR Statement(s).

        Returns a list because some transformations (like try block hoisting)
        may expand a single Python statement into multiple IR statements.

        Raises UnsupportedPatternError for unsupported statement types.
        """
        if isinstance(node, ast.Assign):
            # Comprehension assignments are desugared into explicit loops
            # before falling through to plain assignment lowering.
            dict_expanded = self._expand_dict_comprehension_assignment(node)
            if dict_expanded is not None:
                return dict_expanded
            expanded = self._expand_list_comprehension_assignment(node)
            if expanded is not None:
                return expanded
            result = self._visit_assign(node)
            return [result] if result else []
        elif isinstance(node, ast.AnnAssign):
            result = self._visit_ann_assign(node)
            return [result] if result else []
        elif isinstance(node, ast.Expr):
            result = self._visit_expr_stmt(node)
            return [result] if result else []
        elif isinstance(node, ast.For):
            return self._visit_for(node)
        elif isinstance(node, ast.If):
            return self._visit_if(node)
        elif isinstance(node, ast.Try):
            return self._visit_try(node)
        elif isinstance(node, ast.Return):
            return self._visit_return(node)
        elif isinstance(node, ast.AugAssign):
            return self._visit_aug_assign(node)
        elif isinstance(node, ast.Pass):
            # Pass statements are fine, they just don't produce IR
            return []
        elif isinstance(node, ast.Break):
            return self._visit_break(node)
        elif isinstance(node, ast.Continue):
            return self._visit_continue(node)

        # Check for unsupported statement types - this MUST raise for any
        # unhandled statement to avoid silently dropping code
        self._check_unsupported_statement(node)
|
|
940
|
+
|
|
941
|
+
def _check_unsupported_statement(self, node: ast.stmt) -> NoReturn:
|
|
942
|
+
"""Check for unsupported statement types and raise descriptive errors.
|
|
943
|
+
|
|
944
|
+
This function ALWAYS raises an exception - it never returns normally.
|
|
945
|
+
Any statement type that reaches this function is either explicitly
|
|
946
|
+
unsupported (with a specific error message) or unhandled (with a
|
|
947
|
+
generic catch-all error). This ensures we never silently drop code.
|
|
948
|
+
"""
|
|
949
|
+
line = getattr(node, "lineno", None)
|
|
950
|
+
col = getattr(node, "col_offset", None)
|
|
951
|
+
|
|
952
|
+
if isinstance(node, ast.While):
|
|
953
|
+
raise UnsupportedPatternError(
|
|
954
|
+
"While loops are not supported",
|
|
955
|
+
RECOMMENDATIONS["while_loop"],
|
|
956
|
+
line=line,
|
|
957
|
+
col=col,
|
|
958
|
+
)
|
|
959
|
+
elif isinstance(node, (ast.With, ast.AsyncWith)):
|
|
960
|
+
raise UnsupportedPatternError(
|
|
961
|
+
"Context managers (with statements) are not supported",
|
|
962
|
+
RECOMMENDATIONS["with_statement"],
|
|
963
|
+
line=line,
|
|
964
|
+
col=col,
|
|
965
|
+
)
|
|
966
|
+
elif isinstance(node, ast.Raise):
|
|
967
|
+
raise UnsupportedPatternError(
|
|
968
|
+
"The 'raise' statement is not supported",
|
|
969
|
+
RECOMMENDATIONS["raise_statement"],
|
|
970
|
+
line=line,
|
|
971
|
+
col=col,
|
|
972
|
+
)
|
|
973
|
+
elif isinstance(node, ast.Assert):
|
|
974
|
+
raise UnsupportedPatternError(
|
|
975
|
+
"Assert statements are not supported",
|
|
976
|
+
RECOMMENDATIONS["assert_statement"],
|
|
977
|
+
line=line,
|
|
978
|
+
col=col,
|
|
979
|
+
)
|
|
980
|
+
elif isinstance(node, ast.Delete):
|
|
981
|
+
raise UnsupportedPatternError(
|
|
982
|
+
"The 'del' statement is not supported",
|
|
983
|
+
RECOMMENDATIONS["delete"],
|
|
984
|
+
line=line,
|
|
985
|
+
col=col,
|
|
986
|
+
)
|
|
987
|
+
elif isinstance(node, ast.Global):
|
|
988
|
+
raise UnsupportedPatternError(
|
|
989
|
+
"Global statements are not supported",
|
|
990
|
+
RECOMMENDATIONS["global_statement"],
|
|
991
|
+
line=line,
|
|
992
|
+
col=col,
|
|
993
|
+
)
|
|
994
|
+
elif isinstance(node, ast.Nonlocal):
|
|
995
|
+
raise UnsupportedPatternError(
|
|
996
|
+
"Nonlocal statements are not supported",
|
|
997
|
+
RECOMMENDATIONS["nonlocal_statement"],
|
|
998
|
+
line=line,
|
|
999
|
+
col=col,
|
|
1000
|
+
)
|
|
1001
|
+
elif isinstance(node, (ast.Import, ast.ImportFrom)):
|
|
1002
|
+
raise UnsupportedPatternError(
|
|
1003
|
+
"Import statements inside workflow run() are not supported",
|
|
1004
|
+
RECOMMENDATIONS["import_statement"],
|
|
1005
|
+
line=line,
|
|
1006
|
+
col=col,
|
|
1007
|
+
)
|
|
1008
|
+
elif isinstance(node, ast.ClassDef):
|
|
1009
|
+
raise UnsupportedPatternError(
|
|
1010
|
+
"Class definitions inside workflow run() are not supported",
|
|
1011
|
+
RECOMMENDATIONS["class_def"],
|
|
1012
|
+
line=line,
|
|
1013
|
+
col=col,
|
|
1014
|
+
)
|
|
1015
|
+
elif isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef)):
|
|
1016
|
+
raise UnsupportedPatternError(
|
|
1017
|
+
"Nested function definitions are not supported",
|
|
1018
|
+
RECOMMENDATIONS["function_def"],
|
|
1019
|
+
line=line,
|
|
1020
|
+
col=col,
|
|
1021
|
+
)
|
|
1022
|
+
elif hasattr(ast, "Match") and isinstance(node, ast.Match):
|
|
1023
|
+
raise UnsupportedPatternError(
|
|
1024
|
+
"Match statements are not supported",
|
|
1025
|
+
RECOMMENDATIONS["match"],
|
|
1026
|
+
line=line,
|
|
1027
|
+
col=col,
|
|
1028
|
+
)
|
|
1029
|
+
else:
|
|
1030
|
+
# Catch-all for any unhandled statement types.
|
|
1031
|
+
# This is critical to avoid silently dropping code.
|
|
1032
|
+
stmt_type = type(node).__name__
|
|
1033
|
+
raise UnsupportedPatternError(
|
|
1034
|
+
f"Unhandled statement type: {stmt_type}",
|
|
1035
|
+
RECOMMENDATIONS["unsupported_statement"],
|
|
1036
|
+
line=line,
|
|
1037
|
+
col=col,
|
|
1038
|
+
)
|
|
1039
|
+
|
|
1040
|
+
    def _expand_list_comprehension_assignment(
        self, node: ast.Assign
    ) -> Optional[List[ir.Statement]]:
        """Expand a list comprehension assignment into loop-based statements.

        Example:
            active_users = [user for user in users if user.active]

        Becomes:
            active_users = []
            for user in users:
                if user.active:
                    active_users = active_users + [user]

        Returns None when the assigned value is not a list comprehension, so
        the caller can fall through to regular assignment handling. Raises
        UnsupportedPatternError for multi-target assignments, multiple
        generators, and async comprehensions.
        """
        if not isinstance(node.value, ast.ListComp):
            return None

        if len(node.targets) != 1 or not isinstance(node.targets[0], ast.Name):
            line = getattr(node, "lineno", None)
            col = getattr(node, "col_offset", None)
            raise UnsupportedPatternError(
                "List comprehension assignments must target a single variable",
                "Assign the comprehension to a simple variable like `results = [x for x in items]`",
                line=line,
                col=col,
            )

        listcomp = node.value
        if len(listcomp.generators) != 1:
            line = getattr(listcomp, "lineno", None)
            col = getattr(listcomp, "col_offset", None)
            raise UnsupportedPatternError(
                "List comprehensions with multiple generators are not supported",
                "Use nested for loops instead of combining multiple generators in one comprehension",
                line=line,
                col=col,
            )

        gen = listcomp.generators[0]
        if gen.is_async:
            line = getattr(listcomp, "lineno", None)
            col = getattr(listcomp, "col_offset", None)
            raise UnsupportedPatternError(
                "Async list comprehensions are not supported",
                "Rewrite using an explicit async for loop",
                line=line,
                col=col,
            )

        target_name = node.targets[0].id

        # Initialize the accumulator list: active_users = []
        init_assign_ast = ast.Assign(
            targets=[ast.Name(id=target_name, ctx=ast.Store())],
            value=ast.List(elts=[], ctx=ast.Load()),
            type_comment=None,
        )
        ast.copy_location(init_assign_ast, node)
        ast.fix_missing_locations(init_assign_ast)

        def _make_append_assignment(value_expr: ast.expr) -> ast.Assign:
            # Builds `target = target + [value_expr]`. Concatenation (rather
            # than .append) keeps the mutation visible to data-flow analysis.
            append_assign = ast.Assign(
                targets=[ast.Name(id=target_name, ctx=ast.Store())],
                value=ast.BinOp(
                    left=ast.Name(id=target_name, ctx=ast.Load()),
                    op=ast.Add(),
                    right=ast.List(elts=[copy.deepcopy(value_expr)], ctx=ast.Load()),
                ),
                type_comment=None,
            )
            ast.copy_location(append_assign, node.value)
            ast.fix_missing_locations(append_assign)
            return append_assign

        append_statements: List[ast.stmt] = []
        if isinstance(listcomp.elt, ast.IfExp):
            # A conditional element `a if cond else b` becomes an if/else
            # that appends one of the two candidate values.
            then_assign = _make_append_assignment(listcomp.elt.body)
            else_assign = _make_append_assignment(listcomp.elt.orelse)
            branch_if = ast.If(
                test=copy.deepcopy(listcomp.elt.test),
                body=[then_assign],
                orelse=[else_assign],
            )
            ast.copy_location(branch_if, listcomp.elt)
            ast.fix_missing_locations(branch_if)
            append_statements.append(branch_if)
        else:
            append_statements.append(_make_append_assignment(listcomp.elt))

        loop_body: List[ast.stmt] = append_statements
        if gen.ifs:
            # Filter clauses guard the append; multiple clauses are ANDed.
            condition: ast.expr
            if len(gen.ifs) == 1:
                condition = copy.deepcopy(gen.ifs[0])
            else:
                condition = ast.BoolOp(op=ast.And(), values=[copy.deepcopy(iff) for iff in gen.ifs])
                ast.copy_location(condition, gen.ifs[0])
            if_stmt = ast.If(test=condition, body=append_statements, orelse=[])
            ast.copy_location(if_stmt, gen.ifs[0])
            ast.fix_missing_locations(if_stmt)
            loop_body = [if_stmt]

        loop_ast = ast.For(
            target=copy.deepcopy(gen.target),
            iter=copy.deepcopy(gen.iter),
            body=loop_body,
            orelse=[],
            type_comment=None,
        )
        ast.copy_location(loop_ast, node)
        ast.fix_missing_locations(loop_ast)

        # Lower the synthesized init + loop through the normal pipeline.
        statements: List[ir.Statement] = []
        init_stmt = self._visit_assign(init_assign_ast)
        if init_stmt:
            statements.append(init_stmt)
        statements.extend(self._visit_for(loop_ast))

        return statements
|
|
1159
|
+
|
|
1160
|
+
    def _expand_dict_comprehension_assignment(
        self, node: ast.Assign
    ) -> Optional[List[ir.Statement]]:
        """Expand a dict comprehension assignment into loop-based statements.

        ``result = {k: v for x in items if cond}`` is rewritten as an empty
        dict initialization followed by a for loop whose body performs
        ``result[k] = v`` (guarded by the filter clauses, if any).

        Returns None when the assigned value is not a dict comprehension.
        Raises UnsupportedPatternError for multi-target assignments, multiple
        generators, and async comprehensions.
        """
        if not isinstance(node.value, ast.DictComp):
            return None

        if len(node.targets) != 1 or not isinstance(node.targets[0], ast.Name):
            line = getattr(node, "lineno", None)
            col = getattr(node, "col_offset", None)
            raise UnsupportedPatternError(
                "Dict comprehension assignments must target a single variable",
                "Assign the comprehension to a simple variable like `result = {k: v for k, v in pairs}`",
                line=line,
                col=col,
            )

        dictcomp = node.value
        if len(dictcomp.generators) != 1:
            line = getattr(dictcomp, "lineno", None)
            col = getattr(dictcomp, "col_offset", None)
            raise UnsupportedPatternError(
                "Dict comprehensions with multiple generators are not supported",
                "Use nested for loops instead of combining multiple generators in one comprehension",
                line=line,
                col=col,
            )

        gen = dictcomp.generators[0]
        if gen.is_async:
            line = getattr(dictcomp, "lineno", None)
            col = getattr(dictcomp, "col_offset", None)
            raise UnsupportedPatternError(
                "Async dict comprehensions are not supported",
                "Rewrite using an explicit async for loop",
                line=line,
                col=col,
            )

        target_name = node.targets[0].id

        # Initialize accumulator: result = {}
        init_assign_ast = ast.Assign(
            targets=[ast.Name(id=target_name, ctx=ast.Store())],
            value=ast.Dict(keys=[], values=[]),
            type_comment=None,
        )
        ast.copy_location(init_assign_ast, node)
        ast.fix_missing_locations(init_assign_ast)

        # result[key] = value
        subscript_target = ast.Subscript(
            value=ast.Name(id=target_name, ctx=ast.Load()),
            slice=copy.deepcopy(dictcomp.key),
            ctx=ast.Store(),
        )
        append_assign_ast = ast.Assign(
            targets=[subscript_target],
            value=copy.deepcopy(dictcomp.value),
            type_comment=None,
        )
        ast.copy_location(append_assign_ast, node.value)
        ast.fix_missing_locations(append_assign_ast)

        loop_body: List[ast.stmt] = []
        if gen.ifs:
            # Filter clauses guard the store; multiple clauses are ANDed.
            condition: ast.expr
            if len(gen.ifs) == 1:
                condition = copy.deepcopy(gen.ifs[0])
            else:
                condition = ast.BoolOp(op=ast.And(), values=[copy.deepcopy(iff) for iff in gen.ifs])
                ast.copy_location(condition, gen.ifs[0])
            if_stmt = ast.If(test=condition, body=[append_assign_ast], orelse=[])
            ast.copy_location(if_stmt, gen.ifs[0])
            ast.fix_missing_locations(if_stmt)
            loop_body.append(if_stmt)
        else:
            loop_body.append(append_assign_ast)

        loop_ast = ast.For(
            target=copy.deepcopy(gen.target),
            iter=copy.deepcopy(gen.iter),
            body=loop_body,
            orelse=[],
            type_comment=None,
        )
        ast.copy_location(loop_ast, node)
        ast.fix_missing_locations(loop_ast)

        # Lower the synthesized init + loop through the normal pipeline.
        statements: List[ir.Statement] = []
        init_stmt = self._visit_assign(init_assign_ast)
        if init_stmt:
            statements.append(init_stmt)
        statements.extend(self._visit_for(loop_ast))

        return statements
|
|
1256
|
+
|
|
1257
|
+
def _visit_ann_assign(self, node: ast.AnnAssign) -> Optional[ir.Statement]:
|
|
1258
|
+
"""Convert annotated assignment to IR when a value is present."""
|
|
1259
|
+
if node.value is None:
|
|
1260
|
+
return None
|
|
1261
|
+
|
|
1262
|
+
assign = ast.Assign(targets=[node.target], value=node.value, type_comment=None)
|
|
1263
|
+
ast.copy_location(assign, node)
|
|
1264
|
+
ast.fix_missing_locations(assign)
|
|
1265
|
+
return self._visit_assign(assign)
|
|
1266
|
+
|
|
1267
|
+
    def _visit_assign(self, node: ast.Assign) -> Optional[ir.Statement]:
        """Convert assignment to IR.

        All assignments with targets use the Assignment statement type.
        This provides uniform unpacking support for:
        - Action calls: a, b = @get_pair()
        - Parallel blocks: a, b = parallel: @x() @y()
        - Regular expressions: a, b = some_list

        Raises UnsupportedPatternError for:
        - Constructor calls: x = MyClass(...)
        - Non-action await: x = await some_func()

        The special cases below are checked in a deliberate order; see the
        inline comments.
        """
        stmt = ir.Statement(span=_make_span(node))
        targets = self._get_assign_targets(node.targets)

        # Check for Pydantic model or dataclass constructor calls
        # These are converted to dict expressions
        model_name = self._is_model_constructor(node.value)
        if model_name and isinstance(node.value, ast.Call):
            value_expr = self._convert_model_constructor_to_dict(node.value, model_name)
            assign = ir.Assignment(targets=targets, value=value_expr)
            stmt.assignment.CopyFrom(assign)
            return stmt

        # Check for constructor calls in assignment (e.g., x = MyModel(...))
        # This must come AFTER the model constructor check since models are allowed
        self._check_constructor_in_assignment(node.value)

        # Check for asyncio.gather() - convert to parallel or spread expression
        if isinstance(node.value, ast.Await) and isinstance(node.value.value, ast.Call):
            gather_call = node.value.value
            if self._is_asyncio_gather_call(gather_call):
                gather_result = self._convert_asyncio_gather(gather_call)
                if gather_result is not None:
                    if isinstance(gather_result, ir.ParallelExpr):
                        value = ir.Expr(parallel_expr=gather_result, span=_make_span(node))
                    else:
                        # SpreadExpr
                        value = ir.Expr(spread_expr=gather_result, span=_make_span(node))
                    assign = ir.Assignment(targets=targets, value=value)
                    stmt.assignment.CopyFrom(assign)
                    return stmt

        # Check if this is an action call - wrap in Assignment for uniform unpacking
        action_call = self._extract_action_call(node.value)
        if action_call:
            value = ir.Expr(action_call=action_call, span=_make_span(node))
            assign = ir.Assignment(targets=targets, value=value)
            stmt.assignment.CopyFrom(assign)
            return stmt

        # Regular assignment (variables, literals, expressions)
        value_expr = self._expr_to_ir_with_model_coercion(node.value)
        if value_expr:
            assign = ir.Assignment(targets=targets, value=value_expr)
            stmt.assignment.CopyFrom(assign)
            return stmt

        # Value could not be lowered; no IR is produced for this statement.
        return None
|
|
1327
|
+
|
|
1328
|
+
    def _visit_expr_stmt(self, node: ast.Expr) -> Optional[ir.Statement]:
        """Convert expression statement to IR (side effect only, no assignment).

        Handles, in order: awaited asyncio.gather calls, bare action calls,
        ``list.append(x)`` rewriting, and finally any other lowerable
        expression. Returns None when nothing lowerable is found.
        """
        stmt = ir.Statement(span=_make_span(node))

        # Check for asyncio.gather() - convert to parallel block statement (side effect)
        if isinstance(node.value, ast.Await) and isinstance(node.value.value, ast.Call):
            gather_call = node.value.value
            if self._is_asyncio_gather_call(gather_call):
                gather_result = self._convert_asyncio_gather(gather_call)
                if gather_result is not None:
                    if isinstance(gather_result, ir.ParallelExpr):
                        # Side effect only - use ParallelBlock statement
                        parallel = ir.ParallelBlock()
                        parallel.calls.extend(gather_result.calls)
                        stmt.parallel_block.CopyFrom(parallel)
                        return stmt
                    else:
                        # SpreadExpr as side effect - wrap in assignment with no targets
                        # This handles: await asyncio.gather(*[action(x) for x in items])
                        value = ir.Expr(spread_expr=gather_result, span=_make_span(node))
                        assign = ir.Assignment(targets=[], value=value)
                        stmt.assignment.CopyFrom(assign)
                        return stmt

        # Check if this is an action call (side effect only)
        action_call = self._extract_action_call(node.value)
        if action_call:
            stmt.action_call.CopyFrom(action_call)
            return stmt

        # Convert list.append(x) to list = list + [x]
        # This makes the mutation explicit so data flows correctly through the DAG
        if isinstance(node.value, ast.Call):
            call = node.value
            if (
                isinstance(call.func, ast.Attribute)
                and call.func.attr == "append"
                and isinstance(call.func.value, ast.Name)
                and len(call.args) == 1
            ):
                list_name = call.func.value.id
                append_value = call.args[0]
                # Create: list = list + [value]
                list_var = ir.Expr(variable=ir.Variable(name=list_name), span=_make_span(node))
                value_expr = self._expr_to_ir_with_model_coercion(append_value)
                if value_expr:
                    # Create [value] as a list literal
                    list_literal = ir.Expr(
                        list=ir.ListExpr(elements=[value_expr]), span=_make_span(node)
                    )
                    # Create list + [value]
                    concat_expr = ir.Expr(
                        binary_op=ir.BinaryOp(
                            op=ir.BinaryOperator.BINARY_OP_ADD, left=list_var, right=list_literal
                        ),
                        span=_make_span(node),
                    )
                    assign = ir.Assignment(targets=[list_name], value=concat_expr)
                    stmt.assignment.CopyFrom(assign)
                    return stmt

        # Regular expression
        expr = self._expr_to_ir_with_model_coercion(node.value)
        if expr:
            stmt.expr_stmt.CopyFrom(ir.ExprStmt(expr=expr))
            return stmt

        return None
|
|
1396
|
+
|
|
1397
|
+
def _visit_for(self, node: ast.For) -> List[ir.Statement]:
|
|
1398
|
+
"""Convert for loop to IR.
|
|
1399
|
+
|
|
1400
|
+
The loop body is emitted as a full block so it can contain multiple
|
|
1401
|
+
statements/calls and early `return`.
|
|
1402
|
+
"""
|
|
1403
|
+
# Get loop variables
|
|
1404
|
+
loop_vars: List[str] = []
|
|
1405
|
+
if isinstance(node.target, ast.Name):
|
|
1406
|
+
loop_vars.append(node.target.id)
|
|
1407
|
+
elif isinstance(node.target, ast.Tuple):
|
|
1408
|
+
for elt in node.target.elts:
|
|
1409
|
+
if isinstance(elt, ast.Name):
|
|
1410
|
+
loop_vars.append(elt.id)
|
|
1411
|
+
|
|
1412
|
+
# Get iterable
|
|
1413
|
+
iterable = self._expr_to_ir_with_model_coercion(node.iter)
|
|
1414
|
+
if not iterable:
|
|
1415
|
+
return []
|
|
1416
|
+
|
|
1417
|
+
# Build body statements (recursively transforms nested structures)
|
|
1418
|
+
body_stmts: List[ir.Statement] = []
|
|
1419
|
+
for body_node in node.body:
|
|
1420
|
+
stmts = self._visit_statement(body_node)
|
|
1421
|
+
body_stmts.extend(stmts)
|
|
1422
|
+
|
|
1423
|
+
stmt = ir.Statement(span=_make_span(node))
|
|
1424
|
+
for_loop = ir.ForLoop(
|
|
1425
|
+
loop_vars=loop_vars,
|
|
1426
|
+
iterable=iterable,
|
|
1427
|
+
block_body=ir.Block(statements=body_stmts, span=_make_span(node)),
|
|
1428
|
+
)
|
|
1429
|
+
stmt.for_loop.CopyFrom(for_loop)
|
|
1430
|
+
return [stmt]
|
|
1431
|
+
|
|
1432
|
+
def _detect_accumulator_targets(
|
|
1433
|
+
self, stmts: List[ir.Statement], in_scope_vars: set
|
|
1434
|
+
) -> List[str]:
|
|
1435
|
+
"""Detect out-of-scope variable modifications in for loop body.
|
|
1436
|
+
|
|
1437
|
+
Scans statements for patterns that modify variables defined outside the loop.
|
|
1438
|
+
Returns a list of accumulator variable names that should be set as targets.
|
|
1439
|
+
|
|
1440
|
+
Supported patterns:
|
|
1441
|
+
1. List append: results.append(value) -> "results"
|
|
1442
|
+
2. Dict subscript: result[key] = value -> "result"
|
|
1443
|
+
3. List/set update methods: results.extend(...), results.add(...) -> "results"
|
|
1444
|
+
|
|
1445
|
+
Note: Patterns like `results = results + [x]` and `count = count + 1` create
|
|
1446
|
+
new assignments which are tracked via in_scope_vars and don't need special
|
|
1447
|
+
detection here - they're handled by the regular assignment target logic.
|
|
1448
|
+
"""
|
|
1449
|
+
accumulators: List[str] = []
|
|
1450
|
+
seen: set = set()
|
|
1451
|
+
|
|
1452
|
+
for stmt in stmts:
|
|
1453
|
+
var_name = self._extract_accumulator_from_stmt(stmt, in_scope_vars)
|
|
1454
|
+
if var_name and var_name not in seen:
|
|
1455
|
+
accumulators.append(var_name)
|
|
1456
|
+
seen.add(var_name)
|
|
1457
|
+
|
|
1458
|
+
# Check conditionals for accumulator targets in branch bodies
|
|
1459
|
+
if stmt.HasField("conditional"):
|
|
1460
|
+
cond = stmt.conditional
|
|
1461
|
+
branch_blocks: list[ir.Block] = []
|
|
1462
|
+
if cond.HasField("if_branch") and cond.if_branch.HasField("block_body"):
|
|
1463
|
+
branch_blocks.append(cond.if_branch.block_body)
|
|
1464
|
+
for branch in cond.elif_branches:
|
|
1465
|
+
if branch.HasField("block_body"):
|
|
1466
|
+
branch_blocks.append(branch.block_body)
|
|
1467
|
+
if cond.HasField("else_branch") and cond.else_branch.HasField("block_body"):
|
|
1468
|
+
branch_blocks.append(cond.else_branch.block_body)
|
|
1469
|
+
|
|
1470
|
+
for block in branch_blocks:
|
|
1471
|
+
for var in self._detect_accumulator_targets(
|
|
1472
|
+
list(block.statements), in_scope_vars
|
|
1473
|
+
):
|
|
1474
|
+
if var not in seen:
|
|
1475
|
+
accumulators.append(var)
|
|
1476
|
+
seen.add(var)
|
|
1477
|
+
|
|
1478
|
+
return accumulators
|
|
1479
|
+
|
|
1480
|
+
def _extract_accumulator_from_stmt(
    self, stmt: ir.Statement, in_scope_vars: set
) -> Optional[str]:
    """Extract accumulator variable name from a single statement.

    Returns the variable name if this statement modifies an out-of-scope
    variable, None otherwise.
    """
    # Pattern 1: mutating method calls (x.append, x.extend, x.add, ...).
    if stmt.HasField("expr_stmt"):
        inner = stmt.expr_stmt.expr
        if inner.HasField("function_call"):
            call_name = inner.function_call.name
            for suffix in (
                ".append",
                ".extend",
                ".add",
                ".update",
                ".insert",
                ".pop",
                ".remove",
                ".clear",
            ):
                if not call_name.endswith(suffix):
                    continue
                base = call_name[: -len(suffix)]
                # Only an out-of-scope receiver counts as an accumulator.
                if base and base not in in_scope_vars:
                    return base

    # Pattern 2: subscript assignment such as d[key] = value -> "d".
    if stmt.HasField("assignment"):
        for target in stmt.assignment.targets:
            bracket = target.find("[")
            if bracket == -1:
                continue
            base = target[:bracket]
            if base and base not in in_scope_vars:
                return base

    # Pattern 3: self-referential assignment such as x = x + [y].
    # The target appears on the RHS, so its previous value must come from
    # outside the body.  in_scope_vars is intentionally not consulted here:
    # the assignment itself adds the target to scope, yet the prior value
    # still originates outside the loop body.
    if stmt.HasField("assignment"):
        assignment = stmt.assignment
        rhs_names = self._collect_variables_from_expr(assignment.value)
        for target in assignment.targets:
            if target in rhs_names:
                return target

    return None
|
|
1534
|
+
|
|
1535
|
+
def _collect_variables_from_expr(self, expr: ir.Expr) -> set:
    """Recursively collect all variable names used in an expression."""
    found: set = set()

    # Leaf case: a bare variable reference contributes its own name.
    if expr.HasField("variable"):
        found.add(expr.variable.name)
        return found

    # Gather the child expressions of the active oneof variant, then recurse.
    children: list = []
    if expr.HasField("binary_op"):
        children = [expr.binary_op.left, expr.binary_op.right]
    elif expr.HasField("unary_op"):
        children = [expr.unary_op.operand]
    elif expr.HasField("list"):
        children = list(expr.list.elements)
    elif expr.HasField("dict"):
        for entry in expr.dict.entries:
            children.append(entry.key)
            children.append(entry.value)
    elif expr.HasField("index"):
        children = [expr.index.value, expr.index.index]
    elif expr.HasField("dot"):
        children = [expr.dot.object]
    elif expr.HasField("function_call"):
        children = [kwarg.value for kwarg in expr.function_call.kwargs]
    elif expr.HasField("action_call"):
        children = [kwarg.value for kwarg in expr.action_call.kwargs]

    for child in children:
        found |= self._collect_variables_from_expr(child)
    return found
|
|
1566
|
+
|
|
1567
|
+
def _visit_if(self, node: ast.If) -> List[ir.Statement]:
    """Convert if statement to IR.

    Normalizes patterns like:
        if await some_action(...):
            ...
    into:
        __if_cond_n__ = await some_action(...)
        if __if_cond_n__:
            ...

    Returns either a flat Conditional statement (possibly preceded by a
    normalization assignment for the first guard), or a nested chain of
    Conditionals when an elif guard needs normalization — nesting keeps
    Python's lazy evaluation of elif conditions intact.
    """

    def normalize_condition(test: ast.expr) -> tuple[List[ir.Statement], Optional[ir.Expr]]:
        # Returns (prefix_statements, condition_expr).  A non-empty prefix
        # holds the hoisted `__if_cond_n__ = await action(...)` assignment.
        action_call = self._extract_action_call(test)
        if action_call is None:
            # Plain expression condition: no hoisting needed.
            return ([], self._expr_to_ir_with_model_coercion(test))

        # An action call embedded in a larger boolean expression (not a bare
        # `await action(...)`) cannot be hoisted without changing semantics.
        if not isinstance(test, ast.Await):
            line = getattr(test, "lineno", None)
            col = getattr(test, "col_offset", None)
            raise UnsupportedPatternError(
                "Action calls inside boolean expressions are not supported in if conditions",
                "Assign the awaited action result to a variable, then use the variable in the if condition.",
                line=line,
                col=col,
            )

        # Hoist the awaited action into a synthetic variable assignment.
        cond_var = self._ctx.next_implicit_fn_name(prefix="if_cond")
        assign_stmt = ir.Statement(span=_make_span(test))
        assign_stmt.assignment.CopyFrom(
            ir.Assignment(
                targets=[cond_var],
                value=ir.Expr(action_call=action_call, span=_make_span(test)),
            )
        )
        cond_expr = ir.Expr(variable=ir.Variable(name=cond_var), span=_make_span(test))
        return ([assign_stmt], cond_expr)

    def visit_body(nodes: list[ast.stmt]) -> List[ir.Statement]:
        # Recursively lower a branch body to IR statements.
        stmts: List[ir.Statement] = []
        for body_node in nodes:
            stmts.extend(self._visit_statement(body_node))
        return stmts

    # Collect if/elif branches as (test_expr, body_nodes, span_node).  A
    # single-If orelse is Python's representation of `elif`.
    branches: list[tuple[ast.expr, list[ast.stmt], ast.AST]] = [(node.test, node.body, node)]
    current = node
    while current.orelse and len(current.orelse) == 1 and isinstance(current.orelse[0], ast.If):
        elif_node = current.orelse[0]
        branches.append((elif_node.test, elif_node.body, elif_node))
        current = elif_node

    # Whatever remains on the deepest node is the final `else` body.
    else_nodes = current.orelse

    normalized: list[
        tuple[List[ir.Statement], Optional[ir.Expr], List[ir.Statement], ast.AST]
    ] = []
    for test_expr, body_nodes, span_node in branches:
        prefix, cond = normalize_condition(test_expr)
        normalized.append((prefix, cond, visit_body(body_nodes), span_node))

    else_body = visit_body(else_nodes) if else_nodes else []

    # If any non-first branch needs normalization, preserve Python semantics by nesting:
    # the hoisted action must only run when the preceding guards were false.
    requires_nested = any(prefix for prefix, _, _, _ in normalized[1:])

    def build_conditional_stmt(
        condition: ir.Expr,
        then_body: List[ir.Statement],
        else_body_statements: List[ir.Statement],
        span_node: ast.AST,
    ) -> ir.Statement:
        # Assemble a two-way Conditional (if/else, no elifs).
        conditional_stmt = ir.Statement(span=_make_span(span_node))
        if_branch = ir.IfBranch(
            condition=condition,
            block_body=ir.Block(statements=then_body, span=_make_span(span_node)),
            span=_make_span(span_node),
        )
        conditional = ir.Conditional(if_branch=if_branch)
        if else_body_statements:
            else_branch = ir.ElseBranch(
                block_body=ir.Block(
                    statements=else_body_statements,
                    span=_make_span(span_node),
                ),
                span=_make_span(span_node),
            )
            conditional.else_branch.CopyFrom(else_branch)
        conditional_stmt.conditional.CopyFrom(conditional)
        return conditional_stmt

    if requires_nested:
        # Fold branches right-to-left: each branch's prefix + conditional
        # becomes the else-body of the branch before it.
        nested_else: List[ir.Statement] = else_body
        for prefix, cond, then_body, span_node in reversed(normalized):
            if cond is None:
                continue
            nested_if_stmt = build_conditional_stmt(
                condition=cond,
                then_body=then_body,
                else_body_statements=nested_else,
                span_node=span_node,
            )
            nested_else = [*prefix, nested_if_stmt]
        return nested_else

    # Flat conditional with elif/else (original behavior), plus optional prefix for the if guard.
    if_prefix, if_condition, if_body, if_span_node = normalized[0]
    if if_condition is None:
        # Condition could not be lowered; emit nothing.
        return []

    conditional_stmt = ir.Statement(span=_make_span(if_span_node))
    if_branch = ir.IfBranch(
        condition=if_condition,
        block_body=ir.Block(statements=if_body, span=_make_span(if_span_node)),
        span=_make_span(if_span_node),
    )
    conditional = ir.Conditional(if_branch=if_branch)

    for _, elif_condition, elif_body, elif_span_node in normalized[1:]:
        if elif_condition is None:
            continue
        elif_branch = ir.ElifBranch(
            condition=elif_condition,
            block_body=ir.Block(statements=elif_body, span=_make_span(elif_span_node)),
            span=_make_span(elif_span_node),
        )
        conditional.elif_branches.append(elif_branch)

    if else_body:
        else_branch = ir.ElseBranch(
            block_body=ir.Block(statements=else_body, span=_make_span(if_span_node)),
            span=_make_span(if_span_node),
        )
        conditional.else_branch.CopyFrom(else_branch)

    conditional_stmt.conditional.CopyFrom(conditional)
    return [*if_prefix, conditional_stmt]
|
|
1704
|
+
|
|
1705
|
+
def _collect_assigned_vars(self, stmts: List[ir.Statement]) -> set:
    """Collect all variable names assigned in a list of statements."""
    # Shallow scan: only direct assignments at this level are inspected
    # (no recursion into nested blocks).
    names: set = set()
    for statement in stmts:
        if not statement.HasField("assignment"):
            continue
        for target in statement.assignment.targets:
            names.add(target)
    return names
|
|
1712
|
+
|
|
1713
|
+
def _collect_assigned_vars_in_order(self, stmts: List[ir.Statement]) -> list[str]:
|
|
1714
|
+
"""Collect assigned variable names in statement order (deduplicated)."""
|
|
1715
|
+
assigned: list[str] = []
|
|
1716
|
+
seen: set[str] = set()
|
|
1717
|
+
|
|
1718
|
+
for stmt in stmts:
|
|
1719
|
+
if stmt.HasField("assignment"):
|
|
1720
|
+
for target in stmt.assignment.targets:
|
|
1721
|
+
if target not in seen:
|
|
1722
|
+
seen.add(target)
|
|
1723
|
+
assigned.append(target)
|
|
1724
|
+
|
|
1725
|
+
if stmt.HasField("conditional"):
|
|
1726
|
+
cond = stmt.conditional
|
|
1727
|
+
if cond.HasField("if_branch") and cond.if_branch.HasField("block_body"):
|
|
1728
|
+
for target in self._collect_assigned_vars_in_order(
|
|
1729
|
+
list(cond.if_branch.block_body.statements)
|
|
1730
|
+
):
|
|
1731
|
+
if target not in seen:
|
|
1732
|
+
seen.add(target)
|
|
1733
|
+
assigned.append(target)
|
|
1734
|
+
for elif_branch in cond.elif_branches:
|
|
1735
|
+
if elif_branch.HasField("block_body"):
|
|
1736
|
+
for target in self._collect_assigned_vars_in_order(
|
|
1737
|
+
list(elif_branch.block_body.statements)
|
|
1738
|
+
):
|
|
1739
|
+
if target not in seen:
|
|
1740
|
+
seen.add(target)
|
|
1741
|
+
assigned.append(target)
|
|
1742
|
+
if cond.HasField("else_branch") and cond.else_branch.HasField("block_body"):
|
|
1743
|
+
for target in self._collect_assigned_vars_in_order(
|
|
1744
|
+
list(cond.else_branch.block_body.statements)
|
|
1745
|
+
):
|
|
1746
|
+
if target not in seen:
|
|
1747
|
+
seen.add(target)
|
|
1748
|
+
assigned.append(target)
|
|
1749
|
+
|
|
1750
|
+
if stmt.HasField("for_loop") and stmt.for_loop.HasField("block_body"):
|
|
1751
|
+
for target in self._collect_assigned_vars_in_order(
|
|
1752
|
+
list(stmt.for_loop.block_body.statements)
|
|
1753
|
+
):
|
|
1754
|
+
if target not in seen:
|
|
1755
|
+
seen.add(target)
|
|
1756
|
+
assigned.append(target)
|
|
1757
|
+
|
|
1758
|
+
if stmt.HasField("try_except"):
|
|
1759
|
+
try_block = stmt.try_except.try_block
|
|
1760
|
+
if try_block.HasField("span"):
|
|
1761
|
+
for target in self._collect_assigned_vars_in_order(list(try_block.statements)):
|
|
1762
|
+
if target not in seen:
|
|
1763
|
+
seen.add(target)
|
|
1764
|
+
assigned.append(target)
|
|
1765
|
+
for handler in stmt.try_except.handlers:
|
|
1766
|
+
if handler.HasField("block_body"):
|
|
1767
|
+
for target in self._collect_assigned_vars_in_order(
|
|
1768
|
+
list(handler.block_body.statements)
|
|
1769
|
+
):
|
|
1770
|
+
if target not in seen:
|
|
1771
|
+
seen.add(target)
|
|
1772
|
+
assigned.append(target)
|
|
1773
|
+
|
|
1774
|
+
return assigned
|
|
1775
|
+
|
|
1776
|
+
def _collect_variables_from_block(self, block: ir.Block) -> list[str]:
    """Collect variable references from every statement in *block*."""
    statements = list(block.statements)
    return self._collect_variables_from_statements(statements)
|
|
1778
|
+
|
|
1779
|
+
def _collect_variables_from_statements(self, stmts: List[ir.Statement]) -> list[str]:
|
|
1780
|
+
"""Collect variable references from statements in encounter order."""
|
|
1781
|
+
vars_found: list[str] = []
|
|
1782
|
+
seen: set[str] = set()
|
|
1783
|
+
|
|
1784
|
+
for stmt in stmts:
|
|
1785
|
+
if stmt.HasField("assignment") and stmt.assignment.HasField("value"):
|
|
1786
|
+
for var in self._collect_variables_from_expr(stmt.assignment.value):
|
|
1787
|
+
if var not in seen:
|
|
1788
|
+
seen.add(var)
|
|
1789
|
+
vars_found.append(var)
|
|
1790
|
+
|
|
1791
|
+
if stmt.HasField("return_stmt") and stmt.return_stmt.HasField("value"):
|
|
1792
|
+
for var in self._collect_variables_from_expr(stmt.return_stmt.value):
|
|
1793
|
+
if var not in seen:
|
|
1794
|
+
seen.add(var)
|
|
1795
|
+
vars_found.append(var)
|
|
1796
|
+
|
|
1797
|
+
if stmt.HasField("action_call"):
|
|
1798
|
+
expr = ir.Expr(action_call=stmt.action_call, span=stmt.span)
|
|
1799
|
+
for var in self._collect_variables_from_expr(expr):
|
|
1800
|
+
if var not in seen:
|
|
1801
|
+
seen.add(var)
|
|
1802
|
+
vars_found.append(var)
|
|
1803
|
+
|
|
1804
|
+
if stmt.HasField("expr_stmt"):
|
|
1805
|
+
for var in self._collect_variables_from_expr(stmt.expr_stmt.expr):
|
|
1806
|
+
if var not in seen:
|
|
1807
|
+
seen.add(var)
|
|
1808
|
+
vars_found.append(var)
|
|
1809
|
+
|
|
1810
|
+
if stmt.HasField("conditional"):
|
|
1811
|
+
cond = stmt.conditional
|
|
1812
|
+
if cond.HasField("if_branch"):
|
|
1813
|
+
if cond.if_branch.HasField("condition"):
|
|
1814
|
+
for var in self._collect_variables_from_expr(cond.if_branch.condition):
|
|
1815
|
+
if var not in seen:
|
|
1816
|
+
seen.add(var)
|
|
1817
|
+
vars_found.append(var)
|
|
1818
|
+
if cond.if_branch.HasField("block_body"):
|
|
1819
|
+
for var in self._collect_variables_from_block(cond.if_branch.block_body):
|
|
1820
|
+
if var not in seen:
|
|
1821
|
+
seen.add(var)
|
|
1822
|
+
vars_found.append(var)
|
|
1823
|
+
for elif_branch in cond.elif_branches:
|
|
1824
|
+
if elif_branch.HasField("condition"):
|
|
1825
|
+
for var in self._collect_variables_from_expr(elif_branch.condition):
|
|
1826
|
+
if var not in seen:
|
|
1827
|
+
seen.add(var)
|
|
1828
|
+
vars_found.append(var)
|
|
1829
|
+
if elif_branch.HasField("block_body"):
|
|
1830
|
+
for var in self._collect_variables_from_block(elif_branch.block_body):
|
|
1831
|
+
if var not in seen:
|
|
1832
|
+
seen.add(var)
|
|
1833
|
+
vars_found.append(var)
|
|
1834
|
+
if cond.HasField("else_branch") and cond.else_branch.HasField("block_body"):
|
|
1835
|
+
for var in self._collect_variables_from_block(cond.else_branch.block_body):
|
|
1836
|
+
if var not in seen:
|
|
1837
|
+
seen.add(var)
|
|
1838
|
+
vars_found.append(var)
|
|
1839
|
+
|
|
1840
|
+
if stmt.HasField("for_loop"):
|
|
1841
|
+
fl = stmt.for_loop
|
|
1842
|
+
if fl.HasField("iterable"):
|
|
1843
|
+
for var in self._collect_variables_from_expr(fl.iterable):
|
|
1844
|
+
if var not in seen:
|
|
1845
|
+
seen.add(var)
|
|
1846
|
+
vars_found.append(var)
|
|
1847
|
+
if fl.HasField("block_body"):
|
|
1848
|
+
for var in self._collect_variables_from_block(fl.block_body):
|
|
1849
|
+
if var not in seen:
|
|
1850
|
+
seen.add(var)
|
|
1851
|
+
vars_found.append(var)
|
|
1852
|
+
|
|
1853
|
+
if stmt.HasField("try_except"):
|
|
1854
|
+
te = stmt.try_except
|
|
1855
|
+
if te.HasField("try_block"):
|
|
1856
|
+
for var in self._collect_variables_from_block(te.try_block):
|
|
1857
|
+
if var not in seen:
|
|
1858
|
+
seen.add(var)
|
|
1859
|
+
vars_found.append(var)
|
|
1860
|
+
for handler in te.handlers:
|
|
1861
|
+
if handler.HasField("block_body"):
|
|
1862
|
+
for var in self._collect_variables_from_block(handler.block_body):
|
|
1863
|
+
if var not in seen:
|
|
1864
|
+
seen.add(var)
|
|
1865
|
+
vars_found.append(var)
|
|
1866
|
+
|
|
1867
|
+
if stmt.HasField("parallel_block"):
|
|
1868
|
+
for call in stmt.parallel_block.calls:
|
|
1869
|
+
if call.HasField("action"):
|
|
1870
|
+
for kwarg in call.action.kwargs:
|
|
1871
|
+
for var in self._collect_variables_from_expr(kwarg.value):
|
|
1872
|
+
if var not in seen:
|
|
1873
|
+
seen.add(var)
|
|
1874
|
+
vars_found.append(var)
|
|
1875
|
+
elif call.HasField("function"):
|
|
1876
|
+
for kwarg in call.function.kwargs:
|
|
1877
|
+
for var in self._collect_variables_from_expr(kwarg.value):
|
|
1878
|
+
if var not in seen:
|
|
1879
|
+
seen.add(var)
|
|
1880
|
+
vars_found.append(var)
|
|
1881
|
+
|
|
1882
|
+
if stmt.HasField("spread_action"):
|
|
1883
|
+
spread = stmt.spread_action
|
|
1884
|
+
if spread.HasField("collection"):
|
|
1885
|
+
for var in self._collect_variables_from_expr(spread.collection):
|
|
1886
|
+
if var not in seen:
|
|
1887
|
+
seen.add(var)
|
|
1888
|
+
vars_found.append(var)
|
|
1889
|
+
if spread.HasField("action"):
|
|
1890
|
+
for kwarg in spread.action.kwargs:
|
|
1891
|
+
for var in self._collect_variables_from_expr(kwarg.value):
|
|
1892
|
+
if var not in seen:
|
|
1893
|
+
seen.add(var)
|
|
1894
|
+
vars_found.append(var)
|
|
1895
|
+
|
|
1896
|
+
return vars_found
|
|
1897
|
+
|
|
1898
|
+
def _visit_try(self, node: ast.Try) -> List[ir.Statement]:
    """Convert try/except to IR with full block bodies.

    The try body and each handler body are visited recursively, so nested
    control flow inside them is transformed as well.

    NOTE(review): ``node.orelse`` (try/else) and ``node.finalbody``
    (try/finally) are never read here, so those clauses are dropped from
    the IR — confirm upstream validation rejects them.
    """
    # Build try body statements (recursively transforms nested structures)
    try_body: List[ir.Statement] = []
    for body_node in node.body:
        stmts = self._visit_statement(body_node)
        try_body.extend(stmts)

    # Build exception handlers (with wrapping if needed)
    handlers: List[ir.ExceptHandler] = []
    for handler in node.handlers:
        # Only bare names (`except ValueError`) and tuples of names
        # (`except (KeyError, IndexError)`) contribute exception types;
        # other forms (e.g. attribute access) are silently ignored.
        exception_types: List[str] = []
        if handler.type:
            if isinstance(handler.type, ast.Name):
                exception_types.append(handler.type.id)
            elif isinstance(handler.type, ast.Tuple):
                for elt in handler.type.elts:
                    if isinstance(elt, ast.Name):
                        exception_types.append(elt.id)

        # Build handler body (recursively transforms nested structures)
        handler_body: List[ir.Statement] = []
        for handler_node in handler.body:
            stmts = self._visit_statement(handler_node)
            handler_body.extend(stmts)

        except_handler = ir.ExceptHandler(
            exception_types=exception_types,
            block_body=ir.Block(statements=handler_body, span=_make_span(handler)),
            span=_make_span(handler),
        )
        # `except Exc as e` binds the exception variable name.
        if handler.name:
            except_handler.exception_var = handler.name
        handlers.append(except_handler)

    # Build the try/except statement
    try_stmt = ir.Statement(span=_make_span(node))
    try_except = ir.TryExcept(
        handlers=handlers,
        try_block=ir.Block(statements=try_body, span=_make_span(node)),
    )
    try_stmt.try_except.CopyFrom(try_except)

    return [try_stmt]
|
|
1942
|
+
|
|
1943
|
+
def _count_calls(self, stmts: List[ir.Statement]) -> int:
    """Count action calls and function calls in statements.

    Both action calls and function calls (including synthetic functions)
    count toward the limit of one call per control flow body.
    """

    def is_call(statement: ir.Statement) -> bool:
        # A statement counts when it is a bare action call, an assignment
        # whose value is an action/function call, or an expression
        # statement wrapping a function call.
        if statement.HasField("action_call"):
            return True
        if statement.HasField("assignment"):
            value = statement.assignment.value
            return value.HasField("action_call") or value.HasField("function_call")
        if statement.HasField("expr_stmt"):
            return statement.expr_stmt.expr.HasField("function_call")
        return False

    return sum(1 for statement in stmts if is_call(statement))
|
|
1964
|
+
|
|
1965
|
+
def _wrap_body_as_function(
    self,
    body: List[ir.Statement],
    prefix: str,
    node: ast.AST,
    inputs: Optional[List[str]] = None,
    modified_vars: Optional[List[str]] = None,
) -> List[ir.Statement]:
    """Wrap a body with multiple calls into a synthetic function.

    Registers the synthetic function on ``self._ctx.implicit_functions``
    as a side effect.

    Args:
        body: The statements to wrap
        prefix: Name prefix for the synthetic function
        node: AST node for span information
        inputs: Variables to pass as inputs (e.g., loop variables)
        modified_vars: Out-of-scope variables modified in the body.
            These are added as inputs AND returned as outputs,
            enabling functional transformation of external state.

    Returns a list containing a single function call statement (or assignment
    if modified_vars are present).
    """
    fn_name = self._ctx.next_implicit_fn_name(prefix)
    fn_inputs = list(inputs or [])

    # Add modified variables as inputs (they need to be passed in)
    modified_vars = modified_vars or []
    for var in modified_vars:
        if var not in fn_inputs:
            fn_inputs.append(var)

    # If there are modified variables, add a return statement for them
    wrapped_body = list(body)
    if modified_vars:
        # Create return statement: return (var1, var2, ...) or return var1
        if len(modified_vars) == 1:
            # Single variable: return it directly, no wrapping.
            return_expr = ir.Expr(
                variable=ir.Variable(name=modified_vars[0]),
                span=_make_span(node),
            )
        else:
            # Return as list (tuples are represented as lists in IR)
            return_expr = ir.Expr(
                list=ir.ListExpr(
                    elements=[ir.Expr(variable=ir.Variable(name=var)) for var in modified_vars]
                ),
                span=_make_span(node),
            )
        return_stmt = ir.Statement(span=_make_span(node))
        return_stmt.return_stmt.CopyFrom(ir.ReturnStmt(value=return_expr))
        wrapped_body.append(return_stmt)

    # Create the synthetic function
    implicit_fn = ir.FunctionDef(
        name=fn_name,
        io=ir.IoDecl(inputs=fn_inputs, outputs=modified_vars),
        body=ir.Block(statements=wrapped_body),
        span=_make_span(node),
    )
    self._ctx.implicit_functions.append(implicit_fn)

    # Create a function call expression; each input is forwarded as a
    # same-named keyword argument.
    kwargs = [
        ir.Kwarg(name=var, value=ir.Expr(variable=ir.Variable(name=var))) for var in fn_inputs
    ]
    fn_call_expr = ir.Expr(
        function_call=ir.FunctionCall(name=fn_name, kwargs=kwargs),
        span=_make_span(node),
    )

    # If there are modified variables, create an assignment statement
    # so the returned values are assigned back to the variables
    call_stmt = ir.Statement(span=_make_span(node))
    if modified_vars:
        # Create assignment: var1, var2 = fn(...) or var1 = fn(...)
        assign = ir.Assignment(value=fn_call_expr)
        assign.targets.extend(modified_vars)
        call_stmt.assignment.CopyFrom(assign)
    else:
        # Nothing to assign back; emit a bare expression statement.
        call_stmt.expr_stmt.CopyFrom(ir.ExprStmt(expr=fn_call_expr))

    return [call_stmt]
|
|
2047
|
+
|
|
2048
|
+
def _visit_return(self, node: ast.Return) -> List[ir.Statement]:
    """Convert return statement to IR.

    Return statements should only contain variables or literals, not action calls.
    If the return contains an action call, we normalize it:
        return await action()
    becomes:
        _return_tmp = await action()
        return _return_tmp

    Constructor calls (like return MyModel(...)) are not supported and will
    raise an error with a recommendation to use an @action instead.
    """
    if node.value:
        # Check for constructor calls in return (e.g., return MyModel(...))
        self._check_constructor_in_return(node.value)

        # Check if returning an action call - normalize to assignment + return
        action_call = self._extract_action_call(node.value)
        if action_call:
            # Create a temporary variable for the action result.
            # The fixed name is safe to reuse across returns: each
            # normalized return assigns it immediately before reading it.
            tmp_var = "_return_tmp"

            # Create assignment: _return_tmp = await action()
            assign_stmt = ir.Statement(span=_make_span(node))
            value = ir.Expr(action_call=action_call, span=_make_span(node))
            assign = ir.Assignment(targets=[tmp_var], value=value)
            assign_stmt.assignment.CopyFrom(assign)

            # Create return: return _return_tmp
            return_stmt = ir.Statement(span=_make_span(node))
            var_expr = ir.Expr(variable=ir.Variable(name=tmp_var), span=_make_span(node))
            ret = ir.ReturnStmt(value=var_expr)
            return_stmt.return_stmt.CopyFrom(ret)

            return [assign_stmt, return_stmt]

        # Normalize return of function calls into assignment + return
        expr = self._expr_to_ir_with_model_coercion(node.value)
        if expr and expr.HasField("function_call"):
            # Unlike the action-call case above, the temp name is drawn from
            # the context counter so each call site gets a unique variable.
            tmp_var = self._ctx.next_implicit_fn_name(prefix="return_tmp")

            assign_stmt = ir.Statement(span=_make_span(node))
            assign_stmt.assignment.CopyFrom(ir.Assignment(targets=[tmp_var], value=expr))

            return_stmt = ir.Statement(span=_make_span(node))
            var_expr = ir.Expr(variable=ir.Variable(name=tmp_var), span=_make_span(node))
            return_stmt.return_stmt.CopyFrom(ir.ReturnStmt(value=var_expr))
            return [assign_stmt, return_stmt]

        # Regular return with expression (variable, literal, etc.)
        if expr:
            stmt = ir.Statement(span=_make_span(node))
            return_stmt = ir.ReturnStmt(value=expr)
            stmt.return_stmt.CopyFrom(return_stmt)
            return [stmt]

    # Return with no value (also reached when the expression could not be
    # lowered to IR above).
    stmt = ir.Statement(span=_make_span(node))
    stmt.return_stmt.CopyFrom(ir.ReturnStmt())
    return [stmt]
|
|
2109
|
+
|
|
2110
|
+
def _visit_break(self, node: ast.Break) -> List[ir.Statement]:
    """Lower a Python ``break`` into a single IR statement."""
    ir_stmt = ir.Statement(span=_make_span(node))
    ir_stmt.break_stmt.CopyFrom(ir.BreakStmt())
    return [ir_stmt]
|
|
2115
|
+
|
|
2116
|
+
def _visit_continue(self, node: ast.Continue) -> List[ir.Statement]:
    """Lower a Python ``continue`` into a single IR statement."""
    ir_stmt = ir.Statement(span=_make_span(node))
    ir_stmt.continue_stmt.CopyFrom(ir.ContinueStmt())
    return [ir_stmt]
|
|
2121
|
+
|
|
2122
|
+
def _visit_aug_assign(self, node: ast.AugAssign) -> List[ir.Statement]:
    """Convert augmented assignment (+=, -=, etc.) to IR."""
    # For now, we can represent this as a regular assignment with binary op
    # target op= value -> target = target op value
    stmt = ir.Statement(span=_make_span(node))

    # NOTE(review): only simple Name targets populate `targets`; for
    # subscript/attribute targets the list stays empty, so the emitted
    # assignment has no targets — confirm upstream validation rejects
    # those forms.
    targets: List[str] = []
    if isinstance(node.target, ast.Name):
        targets.append(node.target.id)

    left = self._expr_to_ir_with_model_coercion(node.target)
    right = self._expr_to_ir_with_model_coercion(node.value)
    if right and right.HasField("function_call"):
        # RHS is a function call: hoist it into a temp so the binary op
        # only references a plain variable.
        tmp_var = self._ctx.next_implicit_fn_name(prefix="aug_tmp")

        assign_tmp = ir.Statement(span=_make_span(node))
        assign_tmp.assignment.CopyFrom(
            ir.Assignment(
                targets=[tmp_var],
                value=ir.Expr(function_call=right.function_call, span=_make_span(node)),
            )
        )

        if left:
            op = _bin_op_to_ir(node.op)
            if op:
                # target = target <op> tmp
                binary = ir.BinaryOp(
                    left=left,
                    op=op,
                    right=ir.Expr(variable=ir.Variable(name=tmp_var)),
                )
                value = ir.Expr(binary_op=binary)
                assign = ir.Assignment(targets=targets, value=value)
                stmt.assignment.CopyFrom(assign)
                return [assign_tmp, stmt]
        # Unsupported target or operator: keep the hoisted call (it may
        # have side effects) but drop the combining assignment.
        return [assign_tmp]

    if left and right:
        op = _bin_op_to_ir(node.op)
        if op:
            # target = target <op> value
            binary = ir.BinaryOp(left=left, op=op, right=right)
            value = ir.Expr(binary_op=binary)
            assign = ir.Assignment(targets=targets, value=value)
            stmt.assignment.CopyFrom(assign)
            return [stmt]

    # NOTE(review): unsupported operand or operator falls through to an
    # empty list, silently dropping the statement — confirm this is the
    # intended behavior.
    return []
|
|
2169
|
+
|
|
2170
|
+
def _collect_function_inputs(self, node: ast.FunctionDef | ast.AsyncFunctionDef) -> List[str]:
|
|
2171
|
+
"""Collect workflow inputs from function parameters, including kw-only args."""
|
|
2172
|
+
args: List[str] = []
|
|
2173
|
+
seen: set[str] = set()
|
|
2174
|
+
|
|
2175
|
+
ordered_args = list(node.args.posonlyargs) + list(node.args.args)
|
|
2176
|
+
if ordered_args and ordered_args[0].arg == "self":
|
|
2177
|
+
ordered_args = ordered_args[1:]
|
|
2178
|
+
|
|
2179
|
+
for arg in ordered_args:
|
|
2180
|
+
if arg.arg not in seen:
|
|
2181
|
+
args.append(arg.arg)
|
|
2182
|
+
seen.add(arg.arg)
|
|
2183
|
+
|
|
2184
|
+
if node.args.vararg and node.args.vararg.arg not in seen:
|
|
2185
|
+
args.append(node.args.vararg.arg)
|
|
2186
|
+
seen.add(node.args.vararg.arg)
|
|
2187
|
+
|
|
2188
|
+
for arg in node.args.kwonlyargs:
|
|
2189
|
+
if arg.arg not in seen:
|
|
2190
|
+
args.append(arg.arg)
|
|
2191
|
+
seen.add(arg.arg)
|
|
2192
|
+
|
|
2193
|
+
if node.args.kwarg and node.args.kwarg.arg not in seen:
|
|
2194
|
+
args.append(node.args.kwarg.arg)
|
|
2195
|
+
seen.add(node.args.kwarg.arg)
|
|
2196
|
+
|
|
2197
|
+
return args
|
|
2198
|
+
|
|
2199
|
+
def _check_constructor_in_return(self, node: ast.expr) -> None:
|
|
2200
|
+
"""Check for constructor calls in return statements.
|
|
2201
|
+
|
|
2202
|
+
Raises UnsupportedPatternError if the return value is a class instantiation
|
|
2203
|
+
like: return MyModel(field=value)
|
|
2204
|
+
|
|
2205
|
+
This is not supported because the workflow IR cannot serialize arbitrary
|
|
2206
|
+
object instantiation. Users should use an @action to create objects.
|
|
2207
|
+
"""
|
|
2208
|
+
# Skip if it's an await (action call) - those are fine
|
|
2209
|
+
if isinstance(node, ast.Await):
|
|
2210
|
+
return
|
|
2211
|
+
|
|
2212
|
+
# Check for direct Call that looks like a constructor
|
|
2213
|
+
if isinstance(node, ast.Call):
|
|
2214
|
+
func_name = self._get_constructor_name(node.func)
|
|
2215
|
+
if func_name and self._looks_like_constructor(func_name, node):
|
|
2216
|
+
line = getattr(node, "lineno", None)
|
|
2217
|
+
col = getattr(node, "col_offset", None)
|
|
2218
|
+
raise UnsupportedPatternError(
|
|
2219
|
+
f"Returning constructor call '{func_name}(...)' is not supported",
|
|
2220
|
+
RECOMMENDATIONS["constructor_return"],
|
|
2221
|
+
line=line,
|
|
2222
|
+
col=col,
|
|
2223
|
+
)
|
|
2224
|
+
|
|
2225
|
+
def _check_constructor_in_assignment(self, node: ast.expr) -> None:
|
|
2226
|
+
"""Check for constructor calls in assignments.
|
|
2227
|
+
|
|
2228
|
+
Raises UnsupportedPatternError if the assignment value is a class instantiation
|
|
2229
|
+
like: result = MyModel(field=value)
|
|
2230
|
+
|
|
2231
|
+
This is not supported because the workflow IR cannot serialize arbitrary
|
|
2232
|
+
object instantiation. Users should use an @action to create objects.
|
|
2233
|
+
"""
|
|
2234
|
+
# Skip if it's an await (action call) - those are fine
|
|
2235
|
+
if isinstance(node, ast.Await):
|
|
2236
|
+
return
|
|
2237
|
+
|
|
2238
|
+
# Check for direct Call that looks like a constructor
|
|
2239
|
+
if isinstance(node, ast.Call):
|
|
2240
|
+
func_name = self._get_constructor_name(node.func)
|
|
2241
|
+
if func_name and self._looks_like_constructor(func_name, node):
|
|
2242
|
+
line = getattr(node, "lineno", None)
|
|
2243
|
+
col = getattr(node, "col_offset", None)
|
|
2244
|
+
raise UnsupportedPatternError(
|
|
2245
|
+
f"Assigning constructor call '{func_name}(...)' is not supported",
|
|
2246
|
+
RECOMMENDATIONS["constructor_assignment"],
|
|
2247
|
+
line=line,
|
|
2248
|
+
col=col,
|
|
2249
|
+
)
|
|
2250
|
+
|
|
2251
|
+
def _get_constructor_name(self, func: ast.expr) -> Optional[str]:
|
|
2252
|
+
"""Get the name from a function expression if it looks like a constructor."""
|
|
2253
|
+
if isinstance(func, ast.Name):
|
|
2254
|
+
return func.id
|
|
2255
|
+
elif isinstance(func, ast.Attribute):
|
|
2256
|
+
return func.attr
|
|
2257
|
+
return None
|
|
2258
|
+
|
|
2259
|
+
def _looks_like_constructor(self, func_name: str, call: ast.Call) -> bool:
|
|
2260
|
+
"""Check if a function call looks like a class constructor.
|
|
2261
|
+
|
|
2262
|
+
A constructor is identified by:
|
|
2263
|
+
1. Name starts with uppercase (PEP8 convention for classes)
|
|
2264
|
+
2. It's not a known action
|
|
2265
|
+
3. It's not a known builtin like String operations
|
|
2266
|
+
4. It's not a known Pydantic model or dataclass (those are allowed)
|
|
2267
|
+
|
|
2268
|
+
This is a heuristic - we can't perfectly distinguish constructors
|
|
2269
|
+
from functions without full type information.
|
|
2270
|
+
"""
|
|
2271
|
+
# Check if first letter is uppercase (class naming convention)
|
|
2272
|
+
if not func_name or not func_name[0].isupper():
|
|
2273
|
+
return False
|
|
2274
|
+
|
|
2275
|
+
# If it's a known action, it's not a constructor
|
|
2276
|
+
if func_name in self._action_defs:
|
|
2277
|
+
return False
|
|
2278
|
+
|
|
2279
|
+
# If it's a known Pydantic model or dataclass, allow it
|
|
2280
|
+
# (it will be converted to a dict expression)
|
|
2281
|
+
if func_name in self._model_defs:
|
|
2282
|
+
return False
|
|
2283
|
+
|
|
2284
|
+
# Common builtins that start with uppercase but aren't constructors
|
|
2285
|
+
# (these are rarely used in workflow code but let's be safe)
|
|
2286
|
+
builtin_exceptions = {"True", "False", "None", "Ellipsis"}
|
|
2287
|
+
if func_name in builtin_exceptions:
|
|
2288
|
+
return False
|
|
2289
|
+
|
|
2290
|
+
return True
|
|
2291
|
+
|
|
2292
|
+
def _is_model_constructor(self, node: ast.expr) -> Optional[str]:
|
|
2293
|
+
"""Check if an expression is a Pydantic model or dataclass constructor call.
|
|
2294
|
+
|
|
2295
|
+
Returns the model name if it is, None otherwise.
|
|
2296
|
+
"""
|
|
2297
|
+
if not isinstance(node, ast.Call):
|
|
2298
|
+
return None
|
|
2299
|
+
|
|
2300
|
+
func_name = self._get_constructor_name(node.func)
|
|
2301
|
+
if func_name and func_name in self._model_defs:
|
|
2302
|
+
return func_name
|
|
2303
|
+
|
|
2304
|
+
return None
|
|
2305
|
+
|
|
2306
|
+
def _convert_model_constructor_to_dict(self, node: ast.Call, model_name: str) -> ir.Expr:
    """Convert a Pydantic model or dataclass constructor call to a dict expression.

    For example:
        MyModel(field1=value1, field2=value2)
    becomes:
        {"field1": value1, "field2": value2}

    Default values from the model definition are included for fields not
    explicitly provided in the constructor call.

    Args:
        node: The constructor Call node being lowered.
        model_name: Name of the model; must be a key in ``self._model_defs``.

    Returns:
        An ir.Expr carrying a DictExpr with one entry per provided field,
        plus serializable defaults for omitted fields.

    Raises:
        UnsupportedPatternError: on **kwargs expansion, on argument values
            that cannot be lowered to IR, or on excess positional args.
    """
    model_def = self._model_defs[model_name]
    entries: List[ir.DictEntry] = []

    # Track which fields were explicitly provided
    provided_fields: Set[str] = set()

    # First, add all explicitly provided kwargs
    for kw in node.keywords:
        if kw.arg is None:
            # **kwargs expansion - not supported
            line = getattr(node, "lineno", None)
            col = getattr(node, "col_offset", None)
            raise UnsupportedPatternError(
                f"Model constructor '{model_name}' with **kwargs is not supported",
                "Use explicit keyword arguments instead of **kwargs.",
                line=line,
                col=col,
            )

        provided_fields.add(kw.arg)
        # Dict keys are string literals named after the field.
        key_expr = ir.Expr()
        key_literal = ir.Literal()
        key_literal.string_value = kw.arg
        key_expr.literal.CopyFrom(key_literal)

        value_expr = self._expr_to_ir_with_model_coercion(kw.value)
        if value_expr is None:
            # If we can't convert the value, we need to raise an error
            line = getattr(node, "lineno", None)
            col = getattr(node, "col_offset", None)
            raise UnsupportedPatternError(
                f"Cannot convert value for field '{kw.arg}' in '{model_name}'",
                "Use simpler expressions (literals, variables, dicts, lists).",
                line=line,
                col=col,
            )

        entries.append(ir.DictEntry(key=key_expr, value=value_expr))

    # Handle positional arguments (dataclasses support this)
    if node.args:
        # For dataclasses, positional args map to fields in order
        field_names = list(model_def.fields.keys())
        for i, arg in enumerate(node.args):
            if i >= len(field_names):
                line = getattr(node, "lineno", None)
                col = getattr(node, "col_offset", None)
                raise UnsupportedPatternError(
                    f"Too many positional arguments for '{model_name}'",
                    "Use keyword arguments for clarity.",
                    line=line,
                    col=col,
                )

            field_name = field_names[i]
            provided_fields.add(field_name)

            key_expr = ir.Expr()
            key_literal = ir.Literal()
            key_literal.string_value = field_name
            key_expr.literal.CopyFrom(key_literal)

            value_expr = self._expr_to_ir_with_model_coercion(arg)
            if value_expr is None:
                line = getattr(node, "lineno", None)
                col = getattr(node, "col_offset", None)
                raise UnsupportedPatternError(
                    f"Cannot convert positional argument for field '{field_name}' in '{model_name}'",
                    "Use simpler expressions (literals, variables, dicts, lists).",
                    line=line,
                    col=col,
                )

            entries.append(ir.DictEntry(key=key_expr, value=value_expr))

    # Add default values for fields not explicitly provided
    for field_name, field_def in model_def.fields.items():
        if field_name in provided_fields:
            continue

        if field_def.has_default:
            key_expr = ir.Expr()
            key_literal = ir.Literal()
            key_literal.string_value = field_name
            key_expr.literal.CopyFrom(key_literal)

            # Convert the default value to an IR literal
            default_literal = _constant_to_literal(field_def.default_value)
            if default_literal is None:
                # Can't serialize this default - skip it
                # (it's probably a complex object like a list factory)
                continue

            value_expr = ir.Expr()
            value_expr.literal.CopyFrom(default_literal)

            entries.append(ir.DictEntry(key=key_expr, value=value_expr))

    result = ir.Expr(span=_make_span(node))
    result.dict.CopyFrom(ir.DictExpr(entries=entries))
    return result
|
|
2418
|
+
|
|
2419
|
+
def _check_non_action_await(self, node: ast.Await) -> None:
|
|
2420
|
+
"""Check if an await is for a non-action function.
|
|
2421
|
+
|
|
2422
|
+
Note: We can only reliably detect non-action awaits for functions defined
|
|
2423
|
+
in the same module. Actions imported from other modules will pass through
|
|
2424
|
+
and may fail at runtime if they're not actually actions.
|
|
2425
|
+
|
|
2426
|
+
For now, we only check against common builtins and known non-action patterns.
|
|
2427
|
+
A runtime check will catch functions that aren't registered actions.
|
|
2428
|
+
"""
|
|
2429
|
+
awaited = node.value
|
|
2430
|
+
if not isinstance(awaited, ast.Call):
|
|
2431
|
+
return
|
|
2432
|
+
|
|
2433
|
+
# Skip special cases that are handled elsewhere
|
|
2434
|
+
if self._is_run_action_call(awaited):
|
|
2435
|
+
return
|
|
2436
|
+
if self._is_asyncio_sleep_call(awaited):
|
|
2437
|
+
return
|
|
2438
|
+
if self._is_asyncio_gather_call(awaited):
|
|
2439
|
+
return
|
|
2440
|
+
|
|
2441
|
+
# Get the function name
|
|
2442
|
+
func_name = None
|
|
2443
|
+
if isinstance(awaited.func, ast.Name):
|
|
2444
|
+
func_name = awaited.func.id
|
|
2445
|
+
elif isinstance(awaited.func, ast.Attribute):
|
|
2446
|
+
func_name = awaited.func.attr
|
|
2447
|
+
|
|
2448
|
+
if not func_name:
|
|
2449
|
+
return
|
|
2450
|
+
|
|
2451
|
+
# Only raise error for functions defined in THIS module that we know
|
|
2452
|
+
# are NOT actions (i.e., async functions without @action decorator)
|
|
2453
|
+
# We can't reliably detect imported non-actions without full type info.
|
|
2454
|
+
#
|
|
2455
|
+
# The check works by looking at _module_functions which contains
|
|
2456
|
+
# functions defined in the same module as the workflow.
|
|
2457
|
+
if func_name in getattr(self, "_module_functions", set()):
|
|
2458
|
+
if func_name not in self._action_defs:
|
|
2459
|
+
line = getattr(node, "lineno", None)
|
|
2460
|
+
col = getattr(node, "col_offset", None)
|
|
2461
|
+
raise UnsupportedPatternError(
|
|
2462
|
+
f"Awaiting non-action function '{func_name}()' is not supported",
|
|
2463
|
+
RECOMMENDATIONS["non_action_call"],
|
|
2464
|
+
line=line,
|
|
2465
|
+
col=col,
|
|
2466
|
+
)
|
|
2467
|
+
|
|
2468
|
+
def _check_sync_function_call(self, node: ast.Call) -> None:
|
|
2469
|
+
"""Check for synchronous function calls that should be in actions.
|
|
2470
|
+
|
|
2471
|
+
Common patterns like len(), str(), etc. are not supported in workflow code.
|
|
2472
|
+
"""
|
|
2473
|
+
func_name = None
|
|
2474
|
+
if isinstance(node.func, ast.Name):
|
|
2475
|
+
func_name = node.func.id
|
|
2476
|
+
elif isinstance(node.func, ast.Attribute):
|
|
2477
|
+
# Method calls on objects - check the method name
|
|
2478
|
+
func_name = node.func.attr
|
|
2479
|
+
|
|
2480
|
+
if not func_name:
|
|
2481
|
+
return
|
|
2482
|
+
|
|
2483
|
+
# Builtins that users commonly try to use
|
|
2484
|
+
common_builtins = {
|
|
2485
|
+
"len",
|
|
2486
|
+
"str",
|
|
2487
|
+
"int",
|
|
2488
|
+
"float",
|
|
2489
|
+
"bool",
|
|
2490
|
+
"list",
|
|
2491
|
+
"dict",
|
|
2492
|
+
"set",
|
|
2493
|
+
"tuple",
|
|
2494
|
+
"sum",
|
|
2495
|
+
"min",
|
|
2496
|
+
"max",
|
|
2497
|
+
"sorted",
|
|
2498
|
+
"reversed",
|
|
2499
|
+
"enumerate",
|
|
2500
|
+
"zip",
|
|
2501
|
+
"map",
|
|
2502
|
+
"filter",
|
|
2503
|
+
"range",
|
|
2504
|
+
"abs",
|
|
2505
|
+
"round",
|
|
2506
|
+
"print",
|
|
2507
|
+
"type",
|
|
2508
|
+
"isinstance",
|
|
2509
|
+
"hasattr",
|
|
2510
|
+
"getattr",
|
|
2511
|
+
"setattr",
|
|
2512
|
+
"open",
|
|
2513
|
+
"format",
|
|
2514
|
+
}
|
|
2515
|
+
|
|
2516
|
+
if func_name in common_builtins:
|
|
2517
|
+
line = getattr(node, "lineno", None)
|
|
2518
|
+
col = getattr(node, "col_offset", None)
|
|
2519
|
+
raise UnsupportedPatternError(
|
|
2520
|
+
f"Calling built-in function '{func_name}()' directly is not supported",
|
|
2521
|
+
RECOMMENDATIONS["builtin_call"],
|
|
2522
|
+
line=line,
|
|
2523
|
+
col=col,
|
|
2524
|
+
)
|
|
2525
|
+
|
|
2526
|
+
def _extract_action_call(self, node: ast.expr) -> Optional[ir.ActionCall]:
|
|
2527
|
+
"""Extract an action call from an expression if present.
|
|
2528
|
+
|
|
2529
|
+
Also validates that awaited calls are actually @action decorated functions.
|
|
2530
|
+
Raises UnsupportedPatternError if awaiting a non-action function.
|
|
2531
|
+
"""
|
|
2532
|
+
if not isinstance(node, ast.Await):
|
|
2533
|
+
return None
|
|
2534
|
+
|
|
2535
|
+
awaited = node.value
|
|
2536
|
+
# Handle self.run_action(...) wrapper
|
|
2537
|
+
if isinstance(awaited, ast.Call):
|
|
2538
|
+
if self._is_run_action_call(awaited):
|
|
2539
|
+
# Extract the actual action call from run_action
|
|
2540
|
+
if awaited.args:
|
|
2541
|
+
action_call = self._extract_action_call_from_awaitable(awaited.args[0])
|
|
2542
|
+
if action_call:
|
|
2543
|
+
# Extract policies from run_action kwargs (retry, timeout)
|
|
2544
|
+
self._extract_policies_from_run_action(awaited, action_call)
|
|
2545
|
+
return action_call
|
|
2546
|
+
# Check for asyncio.sleep() - convert to @sleep action
|
|
2547
|
+
if self._is_asyncio_sleep_call(awaited):
|
|
2548
|
+
return self._convert_asyncio_sleep_to_action(awaited)
|
|
2549
|
+
# Try to extract as action call
|
|
2550
|
+
action_call = self._extract_action_call_from_call(awaited)
|
|
2551
|
+
if action_call:
|
|
2552
|
+
return action_call
|
|
2553
|
+
|
|
2554
|
+
# If we get here, it's an await of a non-action function
|
|
2555
|
+
self._check_non_action_await(node)
|
|
2556
|
+
return None
|
|
2557
|
+
|
|
2558
|
+
return None
|
|
2559
|
+
|
|
2560
|
+
def _is_run_action_call(self, node: ast.Call) -> bool:
|
|
2561
|
+
"""Check if this is a self.run_action(...) call."""
|
|
2562
|
+
if isinstance(node.func, ast.Attribute):
|
|
2563
|
+
return node.func.attr == "run_action"
|
|
2564
|
+
return False
|
|
2565
|
+
|
|
2566
|
+
def _extract_policies_from_run_action(
    self, run_action_call: ast.Call, action_call: ir.ActionCall
) -> None:
    """Attach retry/timeout policy brackets parsed from run_action kwargs.

    Parses patterns like:
    - self.run_action(action(), retry=RetryPolicy(attempts=3))
    - self.run_action(action(), timeout=timedelta(seconds=30))
    - self.run_action(action(), timeout=60)
    """
    for kw in run_action_call.keywords:
        if kw.arg == "retry":
            parsed = self._parse_retry_policy(kw.value)
            if parsed is None:
                continue
            bracket = ir.PolicyBracket()
            bracket.retry.CopyFrom(parsed)
            action_call.policies.append(bracket)
        elif kw.arg == "timeout":
            parsed = self._parse_timeout_policy(kw.value)
            if parsed is None:
                continue
            bracket = ir.PolicyBracket()
            bracket.timeout.CopyFrom(parsed)
            action_call.policies.append(bracket)
|
|
2589
|
+
|
|
2590
|
+
def _parse_retry_policy(self, node: ast.expr) -> Optional[ir.RetryPolicy]:
    """Parse a RetryPolicy(...) call into IR.

    Supports:
    - RetryPolicy(attempts=3)
    - RetryPolicy(attempts=3, exception_types=["ValueError"])
    - RetryPolicy(attempts=3, backoff_seconds=5)
    - self.retry_policy (instance attribute reference)
    """
    # self.<attr>: resolve recursively through captured instance attributes.
    is_self_attr = (
        isinstance(node, ast.Attribute)
        and isinstance(node.value, ast.Name)
        and node.value.id == "self"
    )
    if is_self_attr:
        if node.attr in self._instance_attrs:
            return self._parse_retry_policy(self._instance_attrs[node.attr])
        return None

    if not isinstance(node, ast.Call):
        return None

    # Accept both RetryPolicy(...) and mod.RetryPolicy(...).
    callee = node.func
    if isinstance(callee, ast.Name):
        callee_name = callee.id
    elif isinstance(callee, ast.Attribute):
        callee_name = callee.attr
    else:
        callee_name = None
    if callee_name != "RetryPolicy":
        return None

    policy = ir.RetryPolicy()
    for kw in node.keywords:
        value = kw.value
        if kw.arg == "attempts" and isinstance(value, ast.Constant):
            # attempts counts total executions; max_retries counts re-runs
            # after the first attempt, so attempts=3 -> max_retries=2.
            policy.max_retries = value.value - 1
        elif kw.arg == "exception_types" and isinstance(value, ast.List):
            policy.exception_types.extend(
                elt.value
                for elt in value.elts
                if isinstance(elt, ast.Constant) and isinstance(elt.value, str)
            )
        elif kw.arg == "backoff_seconds" and isinstance(value, ast.Constant):
            policy.backoff.seconds = int(value.value)
    return policy
|
|
2638
|
+
|
|
2639
|
+
def _parse_timeout_policy(self, node: ast.expr) -> Optional[ir.TimeoutPolicy]:
    """Parse a timeout value into IR.

    Supports:
    - timeout=60 (int seconds)
    - timeout=30.5 (float seconds, truncated to whole seconds)
    - timeout=timedelta(seconds=30)
    - timeout=timedelta(minutes=2)
    - timeout=timedelta(weeks=1)
    - self.timeout (instance attribute reference)

    Returns None when the expression is not a recognized timeout form.
    """
    # Handle self.attr pattern - resolve through captured instance attrs.
    if (
        isinstance(node, ast.Attribute)
        and isinstance(node.value, ast.Name)
        and node.value.id == "self"
    ):
        attr_name = node.attr
        if attr_name in self._instance_attrs:
            return self._parse_timeout_policy(self._instance_attrs[attr_name])
        return None

    policy = ir.TimeoutPolicy()

    # Direct numeric value (seconds)
    if isinstance(node, ast.Constant) and isinstance(node.value, (int, float)):
        policy.timeout.seconds = int(node.value)
        return policy

    # timedelta(...) call
    if isinstance(node, ast.Call):
        func_name = None
        if isinstance(node.func, ast.Name):
            func_name = node.func.id
        elif isinstance(node.func, ast.Attribute):
            func_name = node.func.attr

        if func_name == "timedelta":
            # Seconds-per-unit table for the supported timedelta keywords.
            # Bug fix: "weeks" was previously unrecognized and silently
            # dropped, yielding a zero-second timeout for timedelta(weeks=N).
            unit_seconds = {
                "seconds": 1,
                "minutes": 60,
                "hours": 3600,
                "days": 86400,
                "weeks": 604800,
            }
            total_seconds = 0
            for kw in node.keywords:
                if isinstance(kw.value, ast.Constant) and kw.arg in unit_seconds:
                    total_seconds += int(kw.value.value) * unit_seconds[kw.arg]
            policy.timeout.seconds = total_seconds
            return policy

    return None
|
|
2692
|
+
|
|
2693
|
+
def _is_asyncio_sleep_call(self, node: ast.Call) -> bool:
|
|
2694
|
+
"""Check if this is an asyncio.sleep(...) call.
|
|
2695
|
+
|
|
2696
|
+
Supports both patterns:
|
|
2697
|
+
- import asyncio; asyncio.sleep(1)
|
|
2698
|
+
- from asyncio import sleep; sleep(1)
|
|
2699
|
+
- from asyncio import sleep as s; s(1)
|
|
2700
|
+
"""
|
|
2701
|
+
if isinstance(node.func, ast.Attribute):
|
|
2702
|
+
# asyncio.sleep(...) pattern
|
|
2703
|
+
if node.func.attr == "sleep" and isinstance(node.func.value, ast.Name):
|
|
2704
|
+
return node.func.value.id == "asyncio"
|
|
2705
|
+
elif isinstance(node.func, ast.Name):
|
|
2706
|
+
# sleep(...) pattern - check if it's imported from asyncio
|
|
2707
|
+
func_name = node.func.id
|
|
2708
|
+
if func_name in self._imported_names:
|
|
2709
|
+
imported = self._imported_names[func_name]
|
|
2710
|
+
return imported.module == "asyncio" and imported.original_name == "sleep"
|
|
2711
|
+
return False
|
|
2712
|
+
|
|
2713
|
+
def _convert_asyncio_sleep_to_action(self, node: ast.Call) -> ir.ActionCall:
    """Rewrite asyncio.sleep(duration) as the built-in @sleep action.

    The scheduler treats @sleep as a durable sleep - stored in the DB
    with a future scheduled_at time instead of blocking a worker.
    """
    call = ir.ActionCall(action_name="sleep")

    if node.args:
        # Positional form: asyncio.sleep(1)
        duration = self._expr_to_ir_with_model_coercion(node.args[0])
        if duration:
            call.kwargs.append(ir.Kwarg(name="duration", value=duration))
        return call

    # Keyword form (less common): asyncio.sleep(seconds=1)
    for kw in node.keywords:
        if kw.arg not in ("seconds", "delay", "duration"):
            continue
        duration = self._expr_to_ir_with_model_coercion(kw.value)
        if duration:
            call.kwargs.append(ir.Kwarg(name="duration", value=duration))
        break

    return call
|
|
2737
|
+
|
|
2738
|
+
def _is_asyncio_gather_call(self, node: ast.Call) -> bool:
|
|
2739
|
+
"""Check if this is an asyncio.gather(...) call.
|
|
2740
|
+
|
|
2741
|
+
Supports both patterns:
|
|
2742
|
+
- import asyncio; asyncio.gather(a(), b())
|
|
2743
|
+
- from asyncio import gather; gather(a(), b())
|
|
2744
|
+
- from asyncio import gather as g; g(a(), b())
|
|
2745
|
+
"""
|
|
2746
|
+
if isinstance(node.func, ast.Attribute):
|
|
2747
|
+
# asyncio.gather(...) pattern
|
|
2748
|
+
if node.func.attr == "gather" and isinstance(node.func.value, ast.Name):
|
|
2749
|
+
return node.func.value.id == "asyncio"
|
|
2750
|
+
elif isinstance(node.func, ast.Name):
|
|
2751
|
+
# gather(...) pattern - check if it's imported from asyncio
|
|
2752
|
+
func_name = node.func.id
|
|
2753
|
+
if func_name in self._imported_names:
|
|
2754
|
+
imported = self._imported_names[func_name]
|
|
2755
|
+
return imported.module == "asyncio" and imported.original_name == "gather"
|
|
2756
|
+
return False
|
|
2757
|
+
|
|
2758
|
+
def _convert_asyncio_gather(
    self, node: ast.Call
) -> Optional[Union[ir.ParallelExpr, ir.SpreadExpr]]:
    """Lower asyncio.gather(...) into ParallelExpr or SpreadExpr IR.

    Two shapes are recognized:
    1. Static gather: asyncio.gather(a(), b(), c()) -> ParallelExpr
    2. Spread gather: asyncio.gather(*[action(x) for x in items]) -> SpreadExpr

    Args:
        node: The asyncio.gather() Call node

    Returns:
        A ParallelExpr, SpreadExpr, or None if conversion fails.
    """
    args = node.args

    # A single starred argument means the spread pattern.
    if len(args) == 1 and isinstance(args[0], ast.Starred):
        inner = args[0].value
        if isinstance(inner, ast.ListComp):
            return self._convert_listcomp_to_spread_expr(inner)

        # Anything else under the star cannot be analyzed statically.
        line = getattr(node, "lineno", None)
        col = getattr(node, "col_offset", None)
        if isinstance(inner, ast.Name):
            raise UnsupportedPatternError(
                f"Spreading variable '{inner.id}' in asyncio.gather() is not supported",
                RECOMMENDATIONS["gather_variable_spread"],
                line=line,
                col=col,
            )
        raise UnsupportedPatternError(
            "Spreading non-list-comprehension expressions in asyncio.gather() is not supported",
            RECOMMENDATIONS["gather_variable_spread"],
            line=line,
            col=col,
        )

    # Static fan-out: each argument should be an action call.
    parallel = ir.ParallelExpr()
    for arg in args:
        converted = self._convert_gather_arg_to_call(arg)
        if converted:
            parallel.calls.append(converted)

    # An empty parallel block means nothing was convertible.
    return parallel if parallel.calls else None
|
|
2813
|
+
|
|
2814
|
+
def _convert_listcomp_to_spread_expr(self, listcomp: ast.ListComp) -> Optional[ir.SpreadExpr]:
    """Lower a list comprehension into SpreadExpr IR.

    Handles patterns like:
        [action(x=item) for item in collection]

    The comprehension must have exactly one generator with no conditions,
    a simple loop variable, and an @action call as its element.

    Args:
        listcomp: The ListComp AST node

    Returns:
        A SpreadExpr, or None if conversion fails.
    """
    line = getattr(listcomp, "lineno", None)
    col = getattr(listcomp, "col_offset", None)

    def _reject(message: str, recommendation: str) -> None:
        # All rejection sites share the comprehension's location.
        raise UnsupportedPatternError(message, recommendation, line=line, col=col)

    if len(listcomp.generators) != 1:
        _reject(
            "Spread pattern only supports a single loop variable",
            "Use a simple list comprehension: [action(x) for x in items]",
        )

    gen = listcomp.generators[0]

    if gen.ifs:
        _reject(
            "Spread pattern does not support conditions in list comprehension",
            "Remove the 'if' clause from the comprehension",
        )

    if not isinstance(gen.target, ast.Name):
        _reject(
            "Spread pattern requires a simple loop variable",
            "Use a simple variable: [action(x) for x in items]",
        )

    collection_expr = self._expr_to_ir_with_model_coercion(gen.iter)
    if not collection_expr:
        _reject(
            "Could not convert collection expression in spread pattern",
            "Ensure the collection is a simple variable or expression",
        )

    if not isinstance(listcomp.elt, ast.Call):
        _reject(
            "Spread pattern requires an action call in the list comprehension",
            "Use: [action(x=item) for item in items]",
        )

    action_call = self._extract_action_call_from_call(listcomp.elt)
    if not action_call:
        _reject(
            "Spread pattern element must be an @action call",
            "Ensure the function is decorated with @action",
        )

    spread = ir.SpreadExpr()
    spread.collection.CopyFrom(collection_expr)
    spread.loop_var = gen.target.id
    spread.action.CopyFrom(action_call)
    return spread
|
|
2906
|
+
|
|
2907
|
+
def _convert_gather_arg_to_call(self, node: ast.expr) -> Optional[ir.Call]:
    """Wrap one gather() argument in an IR Call.

    Handles both action calls and regular function calls.
    """
    if not isinstance(node, ast.Call):
        return None

    # Prefer the action interpretation.
    action = self._extract_action_call_from_call(node)
    if action:
        wrapper = ir.Call()
        wrapper.action.CopyFrom(action)
        return wrapper

    # Otherwise treat it as an ordinary function call.
    function = self._convert_to_function_call(node)
    if function:
        wrapper = ir.Call()
        wrapper.function.CopyFrom(function)
        return wrapper

    return None
|
|
2930
|
+
|
|
2931
|
+
def _convert_to_function_call(self, node: ast.Call) -> Optional[ir.FunctionCall]:
    """Convert an AST Call into an IR FunctionCall, or None if unnamed."""
    name = self._get_func_name(node.func)
    if not name:
        return None

    fn_call = ir.FunctionCall(name=name)
    global_function = _global_function_for_call(name, node)
    if global_function is not None:
        fn_call.global_function = global_function

    # Positional arguments.
    for arg in node.args:
        converted = self._expr_to_ir_with_model_coercion(arg)
        if converted:
            fn_call.args.append(converted)

    # Keyword arguments (skip **kwargs expansions, which have arg=None).
    for kw in node.keywords:
        if not kw.arg:
            continue
        converted = self._expr_to_ir_with_model_coercion(kw.value)
        if converted:
            fn_call.kwargs.append(ir.Kwarg(name=kw.arg, value=converted))

    # Backfill kwargs that the caller omitted using the function
    # signature's declared defaults.
    self._fill_default_kwargs_for_expr(fn_call)

    return fn_call
|
|
2959
|
+
|
|
2960
|
+
def _get_func_name(self, node: ast.expr) -> Optional[str]:
|
|
2961
|
+
"""Get function name from a func node."""
|
|
2962
|
+
if isinstance(node, ast.Name):
|
|
2963
|
+
return node.id
|
|
2964
|
+
elif isinstance(node, ast.Attribute):
|
|
2965
|
+
# Handle chained attributes like obj.method
|
|
2966
|
+
parts = []
|
|
2967
|
+
current = node
|
|
2968
|
+
while isinstance(current, ast.Attribute):
|
|
2969
|
+
parts.append(current.attr)
|
|
2970
|
+
current = current.value
|
|
2971
|
+
if isinstance(current, ast.Name):
|
|
2972
|
+
parts.append(current.id)
|
|
2973
|
+
name = ".".join(reversed(parts))
|
|
2974
|
+
if name.startswith("self."):
|
|
2975
|
+
return name[5:]
|
|
2976
|
+
return name
|
|
2977
|
+
return None
|
|
2978
|
+
|
|
2979
|
+
def _convert_model_constructor_if_needed(self, node: ast.Call) -> Optional[ir.Expr]:
    """Lower a model-constructor call (Pydantic/dataclass) to a dict expr, or None."""
    model_name = self._is_model_constructor(node)
    if not model_name:
        return None
    return self._convert_model_constructor_to_dict(node, model_name)
def _resolve_enum_attribute(self, node: ast.Attribute) -> Optional[ir.Expr]:
    """Resolve an Enum member access (e.g. ``Color.RED``) to a literal IR expr.

    Returns None when the attribute does not name an enum member; raises
    UnsupportedPatternError when the member's value is not a primitive.
    """
    raw_value = _resolve_enum_attribute_value(node, self._module_globals)
    if raw_value is None:
        return None

    literal = _constant_to_literal(raw_value)
    if literal is None:
        # Enum values that are not None/bool/int/float/str cannot be serialized.
        raise UnsupportedPatternError(
            "Enum value must be a primitive literal",
            RECOMMENDATIONS["unsupported_literal"],
            line=getattr(node, "lineno", None),
            col=getattr(node, "col_offset", None),
        )

    resolved = ir.Expr(span=_make_span(node))
    resolved.literal.CopyFrom(literal)
    return resolved
def _expr_to_ir_with_model_coercion(self, node: ast.expr) -> Optional[ir.Expr]:
    """Convert an AST expression to IR, converting model constructors to dicts."""
    converted = _expr_to_ir(
        node,
        model_converter=self._convert_model_constructor_if_needed,
        enum_resolver=self._resolve_enum_attribute,
    )
    # Ensure every nested function call in the tree carries its declared defaults.
    if converted is not None:
        self._fill_default_kwargs_recursive(converted)
    return converted
def _fill_default_kwargs_recursive(self, expr: ir.Expr) -> None:
    """Recursively fill in default kwargs for all function calls in an expression.

    Walks the expression tree depth-first; for each function call node, the
    default-kwarg injection runs before its children are visited.
    """
    # Collect child expressions of this node, preserving visit order.
    children: List[ir.Expr] = []

    if expr.HasField("function_call"):
        call = expr.function_call
        self._fill_default_kwargs_for_expr(call)
        children.extend(call.args)
        for kw in call.kwargs:
            if kw.value:
                children.append(kw.value)
    elif expr.HasField("binary_op"):
        for side in (expr.binary_op.left, expr.binary_op.right):
            if side:
                children.append(side)
    elif expr.HasField("unary_op"):
        if expr.unary_op.operand:
            children.append(expr.unary_op.operand)
    elif expr.HasField("list"):
        children.extend(expr.list.elements)
    elif expr.HasField("dict"):
        for entry in expr.dict.entries:
            if entry.key:
                children.append(entry.key)
            if entry.value:
                children.append(entry.value)
    elif expr.HasField("index"):
        for part in (expr.index.object, expr.index.index):
            if part:
                children.append(part)
    elif expr.HasField("dot"):
        if expr.dot.object:
            children.append(expr.dot.object)

    for child in children:
        self._fill_default_kwargs_recursive(child)
def _fill_default_kwargs_for_expr(self, fn_call: ir.FunctionCall) -> None:
    """Fill in missing kwargs with default values from the function signature.

    Mutates ``fn_call`` in place; unknown functions are left untouched.
    """
    sig_info = self._ctx.function_signatures.get(fn_call.name)
    if sig_info is None:
        return

    # Names already covered: positional args map onto leading parameters,
    # explicit kwargs are matched by name.
    supplied: Set[str] = {param.name for param, _ in zip(sig_info.parameters, fn_call.args)}
    supplied.update(kw.name for kw in fn_call.kwargs)

    for param in sig_info.parameters:
        if param.name in supplied or not param.has_default:
            continue
        # Only primitive defaults can be expressed as IR literals; others are skipped.
        literal = _constant_to_literal(param.default_value)
        if literal is not None:
            default_expr = ir.Expr()
            default_expr.literal.CopyFrom(literal)
            fn_call.kwargs.append(ir.Kwarg(name=param.name, value=default_expr))
def _extract_action_call_from_awaitable(self, node: ast.expr) -> Optional[ir.ActionCall]:
    """Extract an action call from an awaitable expression, or None if not a call."""
    if not isinstance(node, ast.Call):
        return None
    return self._extract_action_call_from_call(node)
def _extract_action_call_from_call(self, node: ast.Call) -> Optional[ir.ActionCall]:
    """Extract action call info from a Call node.

    Converts positional arguments to keyword arguments using the action's
    signature introspection. This ensures all arguments are named in the IR.

    Pydantic models and dataclass constructors passed as arguments are
    automatically converted to dict expressions.
    """
    action_name = self._get_action_name(node.func)
    if not action_name or action_name not in self._action_defs:
        return None

    action_def = self._action_defs[action_name]
    call_ir = ir.ActionCall(action_name=action_def.action_name)

    # The module name tells the worker where to import the action from.
    if action_def.module_name:
        call_ir.module_name = action_def.module_name

    # Pair positional args with parameter names; zip drops surplus positionals
    # beyond the declared signature, matching the original bounds check.
    param_names = list(action_def.signature.parameters.keys())
    for param_name, arg_node in zip(param_names, node.args):
        converted = self._expr_to_ir_with_model_coercion(arg_node)
        if converted:
            call_ir.kwargs.append(ir.Kwarg(name=param_name, value=converted))

    # Explicit keyword arguments (skip **spreads, which have kw.arg == None).
    for kw in node.keywords:
        if not kw.arg:
            continue
        converted = self._expr_to_ir_with_model_coercion(kw.value)
        if converted:
            call_ir.kwargs.append(ir.Kwarg(name=kw.arg, value=converted))

    return call_ir
def _get_action_name(self, func: ast.expr) -> Optional[str]:
|
|
3137
|
+
"""Get the action name from a function expression."""
|
|
3138
|
+
if isinstance(func, ast.Name):
|
|
3139
|
+
return func.id
|
|
3140
|
+
elif isinstance(func, ast.Attribute):
|
|
3141
|
+
return func.attr
|
|
3142
|
+
return None
|
|
3143
|
+
|
|
3144
|
+
def _get_assign_target(self, targets: List[ast.expr]) -> Optional[str]:
|
|
3145
|
+
"""Get the target variable name from assignment targets (single target only)."""
|
|
3146
|
+
if targets and isinstance(targets[0], ast.Name):
|
|
3147
|
+
return targets[0].id
|
|
3148
|
+
return None
|
|
3149
|
+
|
|
3150
|
+
def _get_assign_targets(self, targets: List[ast.expr]) -> List[str]:
|
|
3151
|
+
"""Get all target variable names from assignment targets (including tuple unpacking)."""
|
|
3152
|
+
result: List[str] = []
|
|
3153
|
+
for t in targets:
|
|
3154
|
+
if isinstance(t, ast.Name):
|
|
3155
|
+
result.append(t.id)
|
|
3156
|
+
elif isinstance(t, ast.Subscript):
|
|
3157
|
+
formatted = _format_subscript_target(t)
|
|
3158
|
+
if formatted:
|
|
3159
|
+
result.append(formatted)
|
|
3160
|
+
elif isinstance(t, ast.Tuple):
|
|
3161
|
+
for elt in t.elts:
|
|
3162
|
+
if isinstance(elt, ast.Name):
|
|
3163
|
+
result.append(elt.id)
|
|
3164
|
+
return result
|
|
3165
|
+
|
|
3166
|
+
|
|
3167
|
+
def _make_span(node: ast.AST) -> ir.Span:
    """Create a Span from an AST node.

    Missing position attributes default to 0; ``end_lineno``/``end_col_offset``
    may be None on synthetic nodes, hence the ``or 0`` coercion.
    """
    end_line = getattr(node, "end_lineno", 0) or 0
    end_col = getattr(node, "end_col_offset", 0) or 0
    return ir.Span(
        start_line=getattr(node, "lineno", 0),
        start_col=getattr(node, "col_offset", 0),
        end_line=end_line,
        end_col=end_col,
    )
def _attribute_chain(node: ast.Attribute) -> Optional[List[str]]:
|
|
3178
|
+
parts: List[str] = []
|
|
3179
|
+
current: ast.AST = node
|
|
3180
|
+
while isinstance(current, ast.Attribute):
|
|
3181
|
+
parts.append(current.attr)
|
|
3182
|
+
current = current.value
|
|
3183
|
+
if isinstance(current, ast.Name):
|
|
3184
|
+
parts.append(current.id)
|
|
3185
|
+
return list(reversed(parts))
|
|
3186
|
+
return None
|
|
3187
|
+
|
|
3188
|
+
|
|
3189
|
+
def _resolve_enum_attribute_value(
    node: ast.Attribute,
    module_globals: Mapping[str, Any],
) -> Optional[Any]:
    """Resolve an attribute chain like ``Color.RED`` to the enum member's value.

    The chain root is looked up in ``module_globals``; intermediate segments
    are resolved via ``__dict__`` (no MRO traversal — TODO confirm intended).
    Returns None unless the penultimate object is an Enum class and the final
    segment names one of its members.
    """
    chain = _attribute_chain(node)
    if not chain or len(chain) < 2:
        return None

    cursor = module_globals.get(chain[0])
    if cursor is None:
        return None

    # Walk intermediate containers (modules, namespace classes).
    for segment in chain[1:-1]:
        try:
            namespace = cursor.__dict__
        except AttributeError:
            return None
        cursor = namespace.get(segment)
        if cursor is None:
            return None

    if not isinstance(cursor, EnumMeta):
        return None
    member = cursor.__members__.get(chain[-1])
    return None if member is None else member.value
def _expr_to_ir(
    expr: ast.AST,
    model_converter: Optional[Callable[[ast.Call], Optional[ir.Expr]]] = None,
    enum_resolver: Optional[Callable[[ast.Attribute], Optional[ir.Expr]]] = None,
) -> Optional[ir.Expr]:
    """Convert Python AST expression to IR Expr.

    Dispatches on the AST node type, recursing into sub-expressions with the
    same converter hooks. Unhandled node types fall through to
    ``_check_unsupported_expression`` at the bottom, which raises for known
    unsupported patterns; otherwise this returns None.

    Args:
        expr: The Python AST node to convert.
        model_converter: Optional hook given first crack at every Call node
            (used to lower Pydantic/dataclass constructors to dict exprs).
        enum_resolver: Optional hook given first crack at every Attribute node
            (used to lower Enum member access to literals).

    Returns:
        The IR expression, or None when the node could not be converted.

    Raises:
        UnsupportedPatternError: for explicitly unsupported constructs
            (f-strings, lambdas, comprehensions, non-allowlisted sync calls, ...).
    """
    result = ir.Expr(span=_make_span(expr))

    # The model converter runs before generic Call handling so constructor
    # calls become dict exprs instead of function calls.
    if isinstance(expr, ast.Call) and model_converter:
        converted = model_converter(expr)
        if converted:
            return converted

    if isinstance(expr, ast.Name):
        result.variable.CopyFrom(ir.Variable(name=expr.id))
        return result

    if isinstance(expr, ast.Constant):
        literal = _constant_to_literal(expr.value)
        if literal:
            result.literal.CopyFrom(literal)
            return result
        # Unsupported constant types (bytes, complex, ...) fall through to the
        # unsupported-expression check below, which raises a descriptive error.

    if isinstance(expr, ast.BinOp):
        left = _expr_to_ir(
            expr.left,
            model_converter=model_converter,
            enum_resolver=enum_resolver,
        )
        right = _expr_to_ir(
            expr.right,
            model_converter=model_converter,
            enum_resolver=enum_resolver,
        )
        op = _bin_op_to_ir(expr.op)
        # Unsupported operators (e.g. **, bitwise) leave op None and fall through.
        if left and right and op:
            result.binary_op.CopyFrom(ir.BinaryOp(left=left, op=op, right=right))
            return result

    if isinstance(expr, ast.UnaryOp):
        operand = _expr_to_ir(
            expr.operand,
            model_converter=model_converter,
            enum_resolver=enum_resolver,
        )
        op = _unary_op_to_ir(expr.op)
        if operand and op:
            result.unary_op.CopyFrom(ir.UnaryOp(op=op, operand=operand))
            return result

    if isinstance(expr, ast.Compare):
        left = _expr_to_ir(
            expr.left,
            model_converter=model_converter,
            enum_resolver=enum_resolver,
        )
        if not left:
            return None
        # For simplicity, handle single comparison
        # (chained comparisons like a < b < c only use the first op/comparator).
        if expr.ops and expr.comparators:
            op = _cmp_op_to_ir(expr.ops[0])
            right = _expr_to_ir(
                expr.comparators[0],
                model_converter=model_converter,
                enum_resolver=enum_resolver,
            )
            if op and right:
                result.binary_op.CopyFrom(ir.BinaryOp(left=left, op=op, right=right))
                return result

    if isinstance(expr, ast.BoolOp):
        values = [
            _expr_to_ir(v, model_converter=model_converter, enum_resolver=enum_resolver)
            for v in expr.values
        ]
        if all(v for v in values):
            op = _bool_op_to_ir(expr.op)
            if op and len(values) >= 2:
                # Chain boolean ops: a and b and c -> (a and b) and c
                # NOTE(review): intermediate chained Exprs carry no span.
                result_expr = values[0]
                for v in values[1:]:
                    if result_expr and v:
                        new_result = ir.Expr()
                        new_result.binary_op.CopyFrom(ir.BinaryOp(left=result_expr, op=op, right=v))
                        result_expr = new_result
                return result_expr

    if isinstance(expr, ast.List):
        elements = [
            _expr_to_ir(e, model_converter=model_converter, enum_resolver=enum_resolver)
            for e in expr.elts
        ]
        if all(e for e in elements):
            list_expr = ir.ListExpr(elements=[e for e in elements if e])
            result.list.CopyFrom(list_expr)
            return result

    if isinstance(expr, ast.Dict):
        entries: List[ir.DictEntry] = []
        # A None key marks a **spread entry; those are skipped silently.
        for k, v in zip(expr.keys, expr.values, strict=False):
            if k:
                key_expr = _expr_to_ir(
                    k,
                    model_converter=model_converter,
                    enum_resolver=enum_resolver,
                )
                value_expr = _expr_to_ir(
                    v,
                    model_converter=model_converter,
                    enum_resolver=enum_resolver,
                )
                if key_expr and value_expr:
                    entries.append(ir.DictEntry(key=key_expr, value=value_expr))
        result.dict.CopyFrom(ir.DictExpr(entries=entries))
        return result

    if isinstance(expr, ast.Subscript):
        obj = _expr_to_ir(
            expr.value,
            model_converter=model_converter,
            enum_resolver=enum_resolver,
        )
        index = (
            _expr_to_ir(
                expr.slice,
                model_converter=model_converter,
                enum_resolver=enum_resolver,
            )
            if isinstance(expr.slice, ast.AST)
            else None
        )
        if obj and index:
            result.index.CopyFrom(ir.IndexAccess(object=obj, index=index))
            return result

    if isinstance(expr, ast.Attribute):
        # Enum resolution takes precedence over generic attribute access.
        if enum_resolver:
            resolved = enum_resolver(expr)
            if resolved:
                return resolved
        obj = _expr_to_ir(
            expr.value,
            model_converter=model_converter,
            enum_resolver=enum_resolver,
        )
        if obj:
            result.dot.CopyFrom(ir.DotAccess(object=obj, attribute=expr.attr))
            return result

    # Awaited calls are converted without the sync-call allowlist check below.
    if isinstance(expr, ast.Await) and isinstance(expr.value, ast.Call):
        func_name = _get_func_name(expr.value.func)
        if func_name:
            args = [
                _expr_to_ir(
                    a,
                    model_converter=model_converter,
                    enum_resolver=enum_resolver,
                )
                for a in expr.value.args
            ]
            kwargs: List[ir.Kwarg] = []
            for kw in expr.value.keywords:
                if kw.arg:
                    kw_expr = _expr_to_ir(
                        kw.value,
                        model_converter=model_converter,
                        enum_resolver=enum_resolver,
                    )
                    if kw_expr:
                        kwargs.append(ir.Kwarg(name=kw.arg, value=kw_expr))
            func_call = ir.FunctionCall(
                name=func_name,
                args=[a for a in args if a],
                kwargs=kwargs,
            )
            global_function = _global_function_for_call(func_name, expr.value)
            if global_function is not None:
                func_call.global_function = global_function
            result.function_call.CopyFrom(func_call)
            return result

    if isinstance(expr, ast.Call):
        # Function call
        # Non-self sync calls must be on the allowlist; attribute-based calls
        # (obj.method()) are rejected outright.
        if not _is_self_method_call(expr):
            func_name = _get_func_name(expr.func) or "unknown"
            if isinstance(expr.func, ast.Attribute):
                line = expr.lineno if hasattr(expr, "lineno") else None
                col = expr.col_offset if hasattr(expr, "col_offset") else None
                raise UnsupportedPatternError(
                    f"Calling synchronous function '{func_name}()' directly is not supported",
                    RECOMMENDATIONS["sync_function_call"],
                    line=line,
                    col=col,
                )
            if func_name not in ALLOWED_SYNC_FUNCTIONS:
                line = expr.lineno if hasattr(expr, "lineno") else None
                col = expr.col_offset if hasattr(expr, "col_offset") else None
                raise UnsupportedPatternError(
                    f"Calling synchronous function '{func_name}()' directly is not supported",
                    RECOMMENDATIONS["sync_function_call"],
                    line=line,
                    col=col,
                )
        func_name = _get_func_name(expr.func)
        if func_name:
            args = [
                _expr_to_ir(
                    a,
                    model_converter=model_converter,
                    enum_resolver=enum_resolver,
                )
                for a in expr.args
            ]
            kwargs: List[ir.Kwarg] = []
            for kw in expr.keywords:
                if kw.arg:
                    kw_expr = _expr_to_ir(
                        kw.value,
                        model_converter=model_converter,
                        enum_resolver=enum_resolver,
                    )
                    if kw_expr:
                        kwargs.append(ir.Kwarg(name=kw.arg, value=kw_expr))
            func_call = ir.FunctionCall(
                name=func_name,
                args=[a for a in args if a],
                kwargs=kwargs,
            )
            global_function = _global_function_for_call(func_name, expr)
            if global_function is not None:
                func_call.global_function = global_function
            result.function_call.CopyFrom(func_call)
            return result

    if isinstance(expr, ast.Tuple):
        # Handle tuple as list for now
        elements = [
            _expr_to_ir(e, model_converter=model_converter, enum_resolver=enum_resolver)
            for e in expr.elts
        ]
        if all(e for e in elements):
            list_expr = ir.ListExpr(elements=[e for e in elements if e])
            result.list.CopyFrom(list_expr)
            return result

    # Check for unsupported expression types
    _check_unsupported_expression(expr)

    return None
def _check_unsupported_expression(expr: ast.AST) -> None:
    """Check for unsupported expression types and raise descriptive errors.

    Statement nodes (non-``ast.expr``) pass through silently; any expression
    node reaching this function is rejected with a pattern-specific message.
    """
    line = getattr(expr, "lineno", None)
    col = getattr(expr, "col_offset", None)

    # Constants with non-primitive values get a dedicated message.
    if isinstance(expr, ast.Constant) and _constant_to_literal(expr.value) is None:
        raise UnsupportedPatternError(
            f"Unsupported literal type '{type(expr.value).__name__}'",
            RECOMMENDATIONS["unsupported_literal"],
            line=line,
            col=col,
        )

    # (node type(s), message, recommendation key) — checked in order.
    rejections = (
        (ast.JoinedStr, "F-strings are not supported", "fstring"),
        (ast.Lambda, "Lambda expressions are not supported", "lambda"),
        (ast.ListComp, "List comprehensions are not supported in this context", "list_comprehension"),
        (ast.DictComp, "Dict comprehensions are not supported in this context", "dict_comprehension"),
        (ast.SetComp, "Set comprehensions are not supported", "set_comprehension"),
        (ast.GeneratorExp, "Generator expressions are not supported", "generator"),
        (ast.NamedExpr, "The walrus operator (:=) is not supported", "walrus"),
        ((ast.Yield, ast.YieldFrom), "Yield expressions are not supported", "yield_statement"),
    )
    for node_types, message, key in rejections:
        if isinstance(expr, node_types):
            raise UnsupportedPatternError(
                message,
                RECOMMENDATIONS[key],
                line=line,
                col=col,
            )

    # Generic catch-all for any other expression node.
    if isinstance(expr, ast.expr):
        raise UnsupportedPatternError(
            f"Unsupported expression type '{type(expr).__name__}'",
            RECOMMENDATIONS["unsupported_expression"],
            line=line,
            col=col,
        )
def _format_subscript_target(target: ast.Subscript) -> Optional[str]:
|
|
3551
|
+
"""Convert a subscript target to a string representation for tracking."""
|
|
3552
|
+
if not isinstance(target.value, ast.Name):
|
|
3553
|
+
return None
|
|
3554
|
+
|
|
3555
|
+
base = target.value.id
|
|
3556
|
+
try:
|
|
3557
|
+
index_str = ast.unparse(target.slice)
|
|
3558
|
+
except Exception:
|
|
3559
|
+
return None
|
|
3560
|
+
|
|
3561
|
+
return f"{base}[{index_str}]"
|
|
3562
|
+
|
|
3563
|
+
|
|
3564
|
+
def _constant_to_literal(value: Any) -> Optional[ir.Literal]:
    """Convert a Python constant to IR Literal, or None for unsupported types."""
    out = ir.Literal()
    if value is None:
        out.is_none = True
        return out
    # bool must be tested before int: bool is an int subclass.
    if isinstance(value, bool):
        out.bool_value = value
        return out
    if isinstance(value, int):
        out.int_value = value
        return out
    if isinstance(value, float):
        out.float_value = value
        return out
    if isinstance(value, str):
        out.string_value = value
        return out
    # bytes, complex, Ellipsis, etc. have no IR representation.
    return None
def _bin_op_to_ir(op: ast.operator) -> Optional[ir.BinaryOperator]:
    """Convert Python binary operator to IR BinaryOperator.

    Returns None for operators with no IR equivalent (e.g. **, bitwise ops).
    """
    translation = {
        ast.Add: ir.BinaryOperator.BINARY_OP_ADD,
        ast.Sub: ir.BinaryOperator.BINARY_OP_SUB,
        ast.Mult: ir.BinaryOperator.BINARY_OP_MUL,
        ast.Div: ir.BinaryOperator.BINARY_OP_DIV,
        ast.FloorDiv: ir.BinaryOperator.BINARY_OP_FLOOR_DIV,
        ast.Mod: ir.BinaryOperator.BINARY_OP_MOD,
    }
    return translation.get(type(op))
def _unary_op_to_ir(op: ast.unaryop) -> Optional[ir.UnaryOperator]:
    """Convert Python unary operator to IR UnaryOperator.

    Only negation and logical not are supported; others map to None.
    """
    translation = {
        ast.USub: ir.UnaryOperator.UNARY_OP_NEG,
        ast.Not: ir.UnaryOperator.UNARY_OP_NOT,
    }
    return translation.get(type(op))
def _cmp_op_to_ir(op: ast.cmpop) -> Optional[ir.BinaryOperator]:
    """Convert Python comparison operator to IR BinaryOperator.

    Identity comparisons (``is`` / ``is not``) have no mapping and yield None.
    """
    translation = {
        ast.Eq: ir.BinaryOperator.BINARY_OP_EQ,
        ast.NotEq: ir.BinaryOperator.BINARY_OP_NE,
        ast.Lt: ir.BinaryOperator.BINARY_OP_LT,
        ast.LtE: ir.BinaryOperator.BINARY_OP_LE,
        ast.Gt: ir.BinaryOperator.BINARY_OP_GT,
        ast.GtE: ir.BinaryOperator.BINARY_OP_GE,
        ast.In: ir.BinaryOperator.BINARY_OP_IN,
        ast.NotIn: ir.BinaryOperator.BINARY_OP_NOT_IN,
    }
    return translation.get(type(op))
def _bool_op_to_ir(op: ast.boolop) -> Optional[ir.BinaryOperator]:
    """Convert Python boolean operator (and/or) to IR BinaryOperator."""
    translation = {
        ast.And: ir.BinaryOperator.BINARY_OP_AND,
        ast.Or: ir.BinaryOperator.BINARY_OP_OR,
    }
    return translation.get(type(op))
def _get_func_name(func: ast.expr) -> Optional[str]:
|
|
3629
|
+
"""Get function name from a Call's func attribute."""
|
|
3630
|
+
if isinstance(func, ast.Name):
|
|
3631
|
+
return func.id
|
|
3632
|
+
elif isinstance(func, ast.Attribute):
|
|
3633
|
+
# For method calls like obj.method, return full dotted name
|
|
3634
|
+
parts = []
|
|
3635
|
+
current = func
|
|
3636
|
+
while isinstance(current, ast.Attribute):
|
|
3637
|
+
parts.append(current.attr)
|
|
3638
|
+
current = current.value
|
|
3639
|
+
if isinstance(current, ast.Name):
|
|
3640
|
+
parts.append(current.id)
|
|
3641
|
+
name = ".".join(reversed(parts))
|
|
3642
|
+
if name.startswith("self."):
|
|
3643
|
+
return name[5:]
|
|
3644
|
+
return name
|
|
3645
|
+
return None
|
|
3646
|
+
|
|
3647
|
+
|
|
3648
|
+
def _is_self_method_call(node: ast.Call) -> bool:
|
|
3649
|
+
"""Return True if the call is a direct self.method(...) invocation."""
|
|
3650
|
+
func = node.func
|
|
3651
|
+
return (
|
|
3652
|
+
isinstance(func, ast.Attribute)
|
|
3653
|
+
and isinstance(func.value, ast.Name)
|
|
3654
|
+
and func.value.id == "self"
|
|
3655
|
+
)
|
|
3656
|
+
|
|
3657
|
+
|
|
3658
|
+
def _global_function_for_call(
    func_name: str, node: ast.Call
) -> Optional[ir.GlobalFunction.ValueType]:
    """Return the GlobalFunction enum value for supported globals.

    A ``self.method(...)`` invocation is a workflow-internal call and never
    maps to a runtime global, even if the name collides.
    """
    return None if _is_self_method_call(node) else GLOBAL_FUNCTIONS.get(func_name)