imandrax-codegen 18.1.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,48 @@
1
+ #!/usr/bin/env python
2
+ import sys
3
+ from pathlib import Path
4
+
5
+ import typer
6
+ from imandrax_codegen.test_gen import gen_test_cases
7
+ from imandrax_codegen.unparse import unparse
8
+
9
# Typer CLI application; exposed as the `imandrax-codegen` console script
# (see the package's entry_points.txt).
app = typer.Typer()
10
+
11
+
12
@app.command()
def main(
    iml_path: str = typer.Argument(
        help='Path of IML file to generate test cases (use "-" to read from stdin)',
    ),
    function: str = typer.Option(
        ...,
        '-f',
        '--function',
        help='Name of function to generate test cases for',
    ),
    output: str | None = typer.Option(
        None,
        '-o',
        '--output',
        help='Output file path (defaults to stdout)',
    ),
) -> None:
    """Generate test cases for IML."""
    # NOTE: the docstring above doubles as the command's --help text.
    # "-" selects standard input instead of a file path.
    iml = sys.stdin.read() if iml_path == '-' else Path(iml_path).read_text()

    source = unparse(gen_test_cases(iml, function))

    # No (or an empty) --output value means: print to stdout.
    if not output:
        typer.echo(source)
        return
    Path(output).write_text(source)
45
+
46
+
47
# Allow `python -m imandrax_codegen`; the installed `imandrax-codegen` console
# script calls `app` directly via entry_points.txt.
if __name__ == '__main__':
    app()
Binary file
@@ -0,0 +1,115 @@
1
+ import base64
2
+ import json
3
+ import subprocess
4
+ import sys
5
+ from functools import singledispatch
6
+ from pathlib import Path
7
+ from typing import Literal
8
+
9
+ import imandrax_codegen.ast_types as ast_types
10
+ from imandrax_api import url_dev, url_prod # noqa: F401
11
+ from imandrax_api_models import Art
12
+ from imandrax_codegen.ast_deserialize import stmts_of_json
13
+
14
+
15
def find_art_parse_exe() -> Path:
    """Locate the bundled ``art_parse.exe`` binary next to this module.

    The executable ships inside the package directory and is currently only
    supported on macOS.

    Returns:
        Path to ``art_parse.exe``.

    Raises:
        ValueError: If the platform is not macOS (``sys.platform`` is not
            ``'darwin'``), or if the executable is missing from the package
            directory.
    """
    # Check the platform first so unsupported platforms get the platform
    # message rather than a missing-file one.
    if sys.platform != 'darwin':
        raise ValueError(
            'Only MacOS is supported for now. Please wait for the next release.'
        )

    exe_path = Path(__file__).parent / 'art_parse.exe'
    if not exe_path.exists():
        raise ValueError(
            f'art_parse.exe not found in {exe_path.parent}. '
            'The package might not be built correctly.'
        )
    return exe_path
34
+
35
+
36
# Resolved once, at import time: importing this module raises ValueError when
# not running on macOS or when the bundled art_parse.exe is missing.
CODEGEN_EXE_PATH = find_art_parse_exe()
37
+
38
+
39
+ def _convert_to_standard_base64(data: str | bytes) -> str:
40
+ """Convert bytes or URL-safe base64 string to standard base64.
41
+
42
+ Handles two cases:
43
+ 1. If data is bytes: directly encode to standard base64
44
+ 2. If data is a URL-safe base64 string: convert to standard base64
45
+
46
+ Pydantic serializes bytes as URL-safe base64 (using - and _ instead of + and /),
47
+ but OCaml's Base64.decode_exn expects standard base64 encoding.
48
+
49
+ Args:
50
+ data: Either raw bytes or URL-safe base64 string
51
+
52
+ Returns:
53
+ Standard base64 string
54
+ """
55
+ if isinstance(data, bytes):
56
+ # Directly encode bytes to standard base64
57
+ return base64.b64encode(data).decode('ascii')
58
+
59
+ # It's a string - assume it's URL-safe base64
60
+ # Add padding if needed
61
+ padding = (4 - len(data) % 4) % 4
62
+ urlsafe_b64_padded = data + ('=' * padding)
63
+
64
+ # Decode URL-safe and re-encode as standard base64
65
+ decoded_bytes = base64.urlsafe_b64decode(urlsafe_b64_padded)
66
+ return base64.b64encode(decoded_bytes).decode('ascii')
67
+
68
+
69
def _serialize_artifact(art: Art) -> str:
    """Dump ``art`` to a JSON string with its ``data`` field in standard base64.

    The ``data`` entry of ``model_dump()`` (raw bytes or URL-safe base64) is
    normalized with ``_convert_to_standard_base64``; all other fields are
    dumped as-is.
    """
    dumped = art.model_dump()
    # Overwriting an existing key keeps its original position in the dict.
    return json.dumps({**dumped, 'data': _convert_to_standard_base64(dumped['data'])})
74
+
75
+
76
@singledispatch
def ast_of_art(
    art: str | Art, mode: Literal['fun-decomp', 'model', 'decl']
) -> list[ast_types.stmt]:
    """Convert an artifact into a list of Python AST statements.

    Dispatches on the type of ``art``: a serialized artifact string goes to the
    ``art_parse`` executable, and an ``Art`` model is serialized first. Any
    other type falls through to this fallback.

    Args:
        art: A serialized artifact string or an ``Art`` model.
        mode: Artifact kind, passed to the executable as ``--mode``.

    Raises:
        NotImplementedError: If ``art`` is neither ``str`` nor ``Art``.
    """
    unsupported = type(art)
    raise NotImplementedError(f'Only Art and str are supported, got {unsupported}')
81
+
82
+
83
@ast_of_art.register
def _(
    art: str,
    mode: Literal['fun-decomp', 'model', 'decl'],
) -> list[ast_types.stmt]:
    """Run the ``art_parse`` executable on a serialized artifact.

    The artifact is written to the executable's stdin, and the JSON it prints
    on stdout is deserialized into statements.

    Raises:
        RuntimeError: If the executable exits with a non-zero status; the
            message includes its stderr.
    """
    # NOTE(review): the two "-" positionals presumably select stdin/stdout as
    # input/output; confirm against art_parse's CLI.
    command = [CODEGEN_EXE_PATH, '-', '-', '--mode', mode]
    proc = subprocess.run(
        command,
        input=art,
        capture_output=True,
        text=True,
        check=False,
    )
    if proc.returncode == 0:
        return stmts_of_json(proc.stdout)
    raise RuntimeError(f'Failed to run generate AST: {proc.stderr}')
105
+
106
+
107
@ast_of_art.register
def _(
    art: Art,
    mode: Literal['fun-decomp', 'model', 'decl'],
) -> list[ast_types.stmt]:
    """Serialize an ``Art`` model to JSON and delegate to the ``str`` handler."""
    serialized: str = _serialize_artifact(art)
    return ast_of_art(serialized, mode)
113
+
114
+
115
+ # END [[ast_of_art]]>
@@ -0,0 +1,102 @@
1
+ """Simple recursive deserializer for OCaml yojson AST format.
2
+
3
+ OCaml format: ["Tag", {...}] for variants with fields, ["Tag"] for empty
4
+ variants.
5
+ Location info is NOT deserialized - use ast.fix to add it later.
6
+ """
7
+
8
+ from __future__ import annotations
9
+
10
+ from typing import Any, cast
11
+
12
+ from . import ast_types as ast
13
+
14
+
15
def deserialize_constant_value(value: Any) -> Any:
    """Unwrap a ``Constant.value`` payload from OCaml's tagged encoding.

    ``["Unit", ...]`` becomes ``None``, and
    ``["String" | "Bytes" | "Bool" | "Int" | "Float", x, ...]`` becomes ``x``.
    Non-list values, lists whose first element is not a string, and lists
    with any other tag are returned unchanged.
    """
    if not isinstance(value, list):
        return value

    items = cast(list[Any], value)
    head = items[0] if items else None
    if not isinstance(head, str):
        # Empty list, or not a tagged value at all.
        return cast(Any, value)

    if head == 'Unit':
        return None
    if head in {'String', 'Bytes', 'Bool', 'Int', 'Float'}:
        # NOTE: a payload-less tag such as ["Int"] raises IndexError here.
        return items[1]
    return cast(Any, value)
30
+
31
+
32
def deserialize(value: Any) -> Any:
    """Recursively turn OCaml yojson data into ``ast_types`` objects.

    Decoding rules:

    * ``None`` and scalars (``str``, ``int``, ``float``, ``bool``, ``bytes``)
      are returned as-is.
    * A ``dict`` is rebuilt with every value deserialized.
    * ``["Tag"]`` becomes ``ast_types.Tag()``.
    * ``["Tag", {...}]`` becomes ``ast_types.Tag(**fields)`` with each field
      deserialized. ``ExprStmt`` is renamed to ``Expr``, and a ``Constant``'s
      ``value`` field goes through ``deserialize_constant_value`` instead.
    * Any other list is deserialized element-wise.
    * Anything else is returned unchanged.

    Raises:
        AttributeError: If a tag does not name an attribute of ``ast_types``.
    """
    if value is None or isinstance(value, (str, int, float, bool, bytes)):
        return value

    if isinstance(value, dict):
        mapping = cast(dict[str, Any], value)
        return {key: deserialize(item) for key, item in mapping.items()}

    if not isinstance(value, list):
        return value

    items = cast(list[Any], value)
    is_tagged = bool(items) and isinstance(items[0], str)

    if is_tagged and len(items) == 1:
        # Payload-less variant, e.g. ["Load"].
        return getattr(ast, items[0])()

    if is_tagged and len(items) == 2 and isinstance(items[1], dict):
        tag: str = items[0]
        fields = cast(dict[str, Any], items[1])
        # OCaml names the expression statement ExprStmt; Python calls it Expr.
        class_name = 'Expr' if tag == 'ExprStmt' else tag
        node_cls = getattr(ast, class_name)
        kwargs: dict[str, Any] = {}
        for key, raw in fields.items():
            if class_name == 'Constant' and key == 'value':
                kwargs[key] = deserialize_constant_value(raw)
            else:
                kwargs[key] = deserialize(raw)
        return node_cls(**kwargs)

    # A plain (untagged) list.
    return [deserialize(item) for item in items]
90
+
91
+
92
def _stmts_of_json_data(json_data: list[Any]) -> list[ast.stmt]:
    """Deserialize already-parsed OCaml JSON (a list of statements)."""
    statements: list[ast.stmt] = deserialize(json_data)
    return statements
95
+
96
+
97
def stmts_of_json(json_string: str) -> list[ast.stmt]:
    """Parse a JSON document produced by the OCaml side into statements."""
    from json import loads

    return _stmts_of_json_data(loads(json_string))
@@ -0,0 +1,472 @@
1
+ """This module closely follows the structure from lib/ast.ml with
2
+ dataclass decorators added.
3
+ """
4
+
5
+ from __future__ import annotations
6
+
7
+ from pydantic import ConfigDict
8
+ from pydantic.dataclasses import dataclass
9
+
10
# Decorator applied to every node class below: a pydantic dataclass whose
# config allows fields annotated with arbitrary (non-pydantic) types.
dc = dataclass(config=ConfigDict(arbitrary_types_allowed=True))


# Root of the node hierarchy, named after Python's ast.AST (unparse.to_stdlib
# maps every node to the stdlib class with the same name).
@dc
class AST:
    pass


# Base class for module-level nodes.
@dc
class mod(AST):
    pass


# A whole module: an ordered list of statements.
@dc
class Module(mod):
    body: list[stmt]
26
+
27
+
28
# Base class for statement nodes (mirrors Python's ast.stmt). Source locations
# are deliberately not modelled; the stdlib AST built from these nodes gets
# them from ast.fix_missing_locations.
# Derives from AST like every other node category in this module (and like the
# stdlib ast.stmt), so `isinstance(node, AST)` also holds for statements.
@dc
class stmt(AST):
    # lineno: int
    # col_offset: int
    # end_lineno: int | None
    # end_col_offset: int | None
    pass


# class <name>(<bases>, <keywords>): <body>, with decorators.
@dc
class ClassDef(stmt):
    name: str
    bases: list[expr]
    keywords: list[keyword]
    body: list[stmt]
    decorator_list: list[expr]


# <targets> = <value>; several targets for chained assignment (a = b = 1).
@dc
class Assign(stmt):
    targets: list[expr]
    value: expr
    type_comment: str | None


# Augmented assignment: <target> <op>= <value>.
@dc
class AugAssign(stmt):
    target: Name | Attribute | Subscript
    op: operator
    value: expr


# Annotated assignment: <target>: <annotation> [= <value>]. As in the stdlib,
# `simple` is 1 for a bare (unparenthesized) Name target, else 0.
@dc
class AnnAssign(stmt):
    target: Name | Attribute | Subscript
    annotation: expr
    value: expr | None
    simple: int


# An expression used as a statement (tagged ExprStmt on the OCaml side).
@dc
class Expr(stmt):
    # L792
    value: expr


# The `pass` statement.
@dc
class Pass(stmt):
    pass
77
+
78
+
79
# Base class for expression nodes (mirrors Python's ast.expr); location
# fields are not modelled.
@dc
class expr(AST):
    pass
    # lineno: int
    # col_offset: int
    # end_lineno: int | None
    # end_col_offset: int | None


# `values[0] <op> values[1] <op> ...` with a single and/or operator.
@dc
class BoolOp(expr):
    op: boolop
    values: list[expr]


# Binary operation: <left> <op> <right>.
@dc
class BinOp(expr):
    left: expr
    op: operator
    right: expr


# Unary operation: <op> <operand>.
@dc
class UnaryOp(expr):
    op: unaryop
    operand: expr


# Dict display. As in the stdlib, a None key marks a `**mapping` unpacking.
@dc
class Dict(expr):
    keys: list[expr | None]
    values: list[expr]


# Set display: {elts...}.
@dc
class Set(expr):
    elts: list[expr]


# Function call: <func>(<args>, <keywords>).
@dc
class Call(expr):
    # L1024
    func: expr
    args: list[expr]
    keywords: list[keyword]
124
+
125
+
126
# original: `_ConstantValue: typing_extensions.TypeAlias = ...`
# Python values a Constant node may hold (PEP 695 `type` alias, Python 3.12+).
type _ConstantValue = (
    str | bytes | bool | int | float | complex | None
    # | EllipsisType
)


# A literal value. `kind` mirrors the stdlib ast.Constant.kind ('u' for
# u-prefixed strings, otherwise None).
@dc
class Constant(expr):
    # L1037
    value: _ConstantValue
    kind: str | None
138
+
139
+
140
# Attribute access: <value>.<attr>.
@dc
class Attribute(expr):
    value: expr
    attr: str
    ctx: expr_context


# Subscript: <value>[<slice>].
@dc
class Subscript(expr):
    # L1139
    value: expr
    slice: expr
    ctx: expr_context


# A bare identifier; `ctx` says whether it is read, assigned or deleted.
@dc
class Name(expr):
    # L1162
    id: str
    ctx: expr_context


# List display: [elts...].
@dc
class List(expr):
    # L1175
    elts: list[expr]
    ctx: expr_context


# Tuple display: (elts...).
@dc
class Tuple(expr):
    # L1185
    elts: list[expr]
    ctx: expr_context
    # dims: list[expr]
174
+ # dims: list[expr]
175
+
176
+
177
# Base class for the usage context of a Name/Attribute/Subscript/List/Tuple.
@dc
class expr_context(AST):
    pass


# The expression is read.
@dc
class Load(expr_context):
    # L1239
    pass


# The expression is an assignment target.
@dc
class Store(expr_context):
    pass


# The expression is a `del` target.
@dc
class Del(expr_context):
    pass
196
+
197
+
198
# Base class for the boolean operators used by BoolOp.
@dc
class boolop(AST):
    pass


# `and`
@dc
class And(boolop):
    pass


# `or`
@dc
class Or(boolop):
    pass
211
+
212
+
213
# Base class for the binary operators used by BinOp and AugAssign.
@dc
class operator(AST):
    pass


@dc
class Add(operator):  # +
    pass


@dc
class Sub(operator):  # -
    pass


@dc
class Mult(operator):  # *
    pass


@dc
class MatMult(operator):  # @
    pass


@dc
class Div(operator):  # /
    pass


@dc
class Mod(operator):  # %
    pass


@dc
class Pow(operator):  # **
    pass


@dc
class LShift(operator):  # <<
    pass


@dc
class RShift(operator):  # >>
    pass


@dc
class BitOr(operator):  # |
    pass


@dc
class BitXor(operator):  # ^
    pass


@dc
class BitAnd(operator):  # &
    pass


@dc
class FloorDiv(operator):  # //
    pass
281
+
282
+
283
# Base class for the unary operators used by UnaryOp.
@dc
class unaryop(AST):
    pass


@dc
class Invert(unaryop):  # ~
    pass


@dc
class Not(unaryop):  # not
    pass


@dc
class UAdd(unaryop):  # unary +
    pass


@dc
class USub(unaryop):  # unary -
    pass
306
+
307
+
308
# Base class for the comparison operators used by Compare.
@dc
class cmpop(AST):
    pass


@dc
class Eq(cmpop):  # ==
    pass


@dc
class NotEq(cmpop):  # !=
    pass


@dc
class Lt(cmpop):  # <
    pass


@dc
class LtE(cmpop):  # <=
    pass


@dc
class Gt(cmpop):  # >
    pass


@dc
class GtE(cmpop):  # >=
    pass


@dc
class Is(cmpop):  # is
    pass


@dc
class IsNot(cmpop):  # is not
    pass


@dc
class In(cmpop):  # in
    pass


@dc
class NotIn(cmpop):  # not in
    pass
361
+
362
+
363
# A single function parameter (mirrors Python's ast.arg); location fields are
# not modelled.
@dc
class arg(AST):
    # lineno: int
    # col_offset: int
    # end_lineno: int | None
    # end_col_offset: int | None
    arg: str
    annotation: expr | None
    type_comment: str | None


# A keyword argument in a call or class definition. As in the stdlib, `arg`
# is None for a `**mapping` unpacking.
@dc
class keyword(AST):
    # lineno: int
    # col_offset: int
    # end_lineno: int | None
    # end_col_offset: int | None
    arg: str | None
    value: expr


# A complete parameter list. As in the stdlib ast.arguments, `defaults` align
# with the last positional parameters, and `kw_defaults` has one entry per
# keyword-only parameter (None when it has no default).
@dc
class arguments(AST):
    posonlyargs: list[arg]
    args: list[arg]
    vararg: arg | None
    kwonlyargs: list[arg]
    kw_defaults: list[expr | None]
    kwarg: arg | None
    defaults: list[expr]


# lambda <args>: <body>
@dc
class Lambda(expr):
    args: arguments
    body: expr
399
+
400
+
401
# Base class for PEP 695 type parameters (`def f[T](...)`, `class C[T]`);
# location fields are not modelled.
@dc
class type_param(AST):
    # lineno: int
    # col_offset: int
    # end_lineno: int
    # end_col_offset: int
    pass


# `T` or `T: bound` type parameter.
# NOTE(review): stdlib ast type-parameter nodes only gained `default_value`
# (PEP 696) in Python 3.13.
@dc
class TypeVar(type_param):
    name: str
    bound: expr | None
    default_value: expr | None


# `**P` parameter-specification type parameter.
@dc
class ParamSpec(type_param):
    name: str
    default_value: expr | None


# `*Ts` variadic type parameter.
@dc
class TypeVarTuple(type_param):
    name: str
    default_value: expr | None
427
+
428
+
429
# def <name>[<type_params>](<args>) -> <returns>: <body>, with decorators.
@dc
class FunctionDef(stmt):
    name: str
    args: arguments
    body: list[stmt]
    decorator_list: list[expr]
    returns: expr | None
    type_comment: str | None
    type_params: list[type_param]
438
+
439
+
440
# One imported name, optionally renamed: `<name> as <asname>`.
@dc
class alias(AST):
    name: str
    asname: str | None
    # lineno: int
    # col_offset: int
    # end_lineno: int | None
    # end_col_offset: int | None
448
+
449
+
450
# assert <test>, <msg>
@dc
class Assert(stmt):
    test: expr
    msg: expr | None


# import <names>
@dc
class Import(stmt):
    names: list[alias]


# from <'.' * level><module> import <names>. As in the stdlib, `module` is
# None for `from . import x`, and `level` 0 means an absolute import.
@dc
class ImportFrom(stmt):
    module: str | None
    names: list[alias]
    level: int
466
+
467
+
468
# Chained comparison: <left> <ops[0]> <comparators[0]> <ops[1]> ...
@dc
class Compare(expr):
    left: expr
    ops: list[cmpop]
    comparators: list[expr]
File without changes
@@ -0,0 +1,66 @@
1
+ #!/usr/bin/env python
2
+ """CLI tool to convert OCaml AST JSON to Python source code."""
3
+
4
+ import sys
5
+ from pathlib import Path
6
+ from typing import Annotated
7
+
8
+ import typer
9
+ from imandrax_codegen.ast_deserialize import stmts_of_json
10
+ from imandrax_codegen.unparse import unparse
11
+
12
# Typer application holding the `code-of-ocaml-ast` command defined below;
# run it with `python -m imandrax_codegen.code_of_ast`.
app = typer.Typer()
13
+
14
+
15
@app.command(name='code-of-ocaml-ast')
def code_of_ocaml_ast(
    input_file: Annotated[
        str,
        typer.Argument(help="Input JSON file (from OCaml yojson), or '-' for stdin"),
    ],
    output: Annotated[
        str | None,
        typer.Option(
            '-o',
            '--output',
            help='Output Python file (writes to stdout if not provided)',
        ),
    ] = None,
    include_real_to_float_alias: Annotated[
        bool,
        typer.Option(
            '--include-real-to-float-alias', help='Include real to float alias'
        ),
    ] = False,
) -> None:
    """Convert OCaml AST JSON to Python source code."""
    # NOTE: the docstring above doubles as the command's --help text.
    import json

    # Read the raw JSON text. JSON is UTF-8 by specification, so decode it
    # explicitly instead of relying on the platform's default encoding.
    if input_file == '-':
        json_str = sys.stdin.read()
    else:
        json_str = Path(input_file).read_text(encoding='utf-8')

    # Whitespace-only input is just as empty as '' and would otherwise only
    # fail later, as a JSONDecodeError traceback.
    if not json_str.strip():
        typer.echo('imandrax_codegen error: Input is empty', err=True)
        raise typer.Exit(code=1)

    # stmts_of_json parses with json.loads; report malformed input as a clean
    # CLI error rather than a traceback.
    try:
        stmts = stmts_of_json(json_str)
    except json.JSONDecodeError as exc:
        typer.echo(f'imandrax_codegen error: Invalid JSON input: {exc}', err=True)
        raise typer.Exit(code=1) from exc

    python_code = unparse(
        stmts,
        alias_real_to_float=include_real_to_float_alias,
    )

    # The file output gets the same trailing newline that typer.echo adds.
    if output:
        Path(output).write_text(python_code + '\n', encoding='utf-8')
    else:
        typer.echo(python_code)
63
+
64
+
65
# Entry point for `python -m imandrax_codegen.code_of_ast`.
if __name__ == '__main__':
    app()
@@ -0,0 +1,11 @@
1
+ from dataclasses import dataclass
2
+
3
+
4
@dataclass
class Waitlist:
    # Positional payload fields (generic arg0/arg1 names).
    arg0: int
    arg1: bool


# `Status` is an alias of the single variant `Waitlist`.
Status = Waitlist
w = Status(arg0=2, arg1=True)
@@ -0,0 +1,112 @@
1
# Marks this module as library code rather than a test, even though its
# filename (test_gen.py) resembles a test module.
# NOTE(review): confirm the test runner in use honors a module-level `__test__`.
__test__ = False  # this is not a test
2
+ import os
3
+ import re
4
+ from pathlib import Path
5
+ from typing import Any, cast
6
+
7
+ import imandrax_codegen.ast_types as ast_types
8
+ from imandrax_api import Client, url_dev, url_prod # noqa: F401
9
+ from imandrax_api.bindings.artmsg_pb2 import Art as PbArt
10
+ from imandrax_api_models import Art, DecomposeRes, EvalRes # noqa: F401, RUF100
11
+ from imandrax_api_models.client import ImandraXClient
12
+
13
+ from .art_parse import ast_of_art
14
+
15
# Directory containing this module.
# NOTE(review): not referenced within this module; possibly used by importers.
curr_dir = Path(__file__).parent
16
+
17
+
18
+ def get_fun_arg_types(fun_name: str, iml: str, c: ImandraXClient) -> list[str] | None:
19
+ """Get the argument types of a function."""
20
+ tc_res = c.typecheck(iml)
21
+ name_ty_map = {ty.name: ty.ty for ty in tc_res.types}
22
+ if fun_name not in name_ty_map:
23
+ return None
24
+
25
+ return list(map(lambda s: s.strip(), name_ty_map[fun_name].split('->')))
26
+
27
+
28
# Reserved OCaml words; never reported as type names.
_OCAML_KEYWORDS = frozenset({
    'and', 'as', 'class', 'constraint', 'exception', 'external', 'fun',
    'function', 'functor', 'in', 'include', 'let', 'match', 'method',
    'module', 'mutable', 'new', 'nonrec', 'of', 'open', 'private', 'rec',
    'sig', 'struct', 'type', 'val', 'virtual', 'when', 'with',
})

# Opening / closing delimiters of OCaml comments.
_COMMENT_DELIM_RE = re.compile(r'\(\*|\*\)')

# A definition keyword, optionally preceded by `(` (locally abstract types,
# first-class module types) or by `module` (module type declarations).
_DECL_KEYWORD_RE = re.compile(
    r'(?P<prefix>\(\s*|\bmodule\s+)?'
    r'\b(?P<kw>type|and|let|module|open|include|exception|external|val|class)\b'
)

# Text following `type`/`and`: an optional `nonrec`, optional type parameters
# (`'a`, `_` for GADTs, or a parenthesized list such as `('a, 'b)`), then the
# declared name(s).
_TYPE_NAME_RE = re.compile(
    r"\s+(?:nonrec\s+)?"
    r"(?:(?:\([^)]+\)|'[a-z_][a-zA-Z0-9_']*|_)\s+)?"
    r"([a-z_][a-zA-Z0-9_']*(?:\s*,\s*[a-z_][a-zA-Z0-9_']*)*)"
)


def _strip_ocaml_comments(code: str) -> str:
    """Replace (possibly nested) ``(* ... *)`` comments with a single space.

    An unterminated comment swallows the rest of the input. Comment
    delimiters inside string literals are not recognized as such.
    """
    pieces: list[str] = []
    depth = 0
    keep_from = 0
    for delim in _COMMENT_DELIM_RE.finditer(code):
        if delim.group() == '(*':
            if depth == 0:
                pieces.append(code[keep_from:delim.start()])
            depth += 1
        elif depth > 0:
            depth -= 1
            if depth == 0:
                pieces.append(' ')
                keep_from = delim.end()
        # A stray '*)' outside any comment is left in the text.
    if depth == 0:
        pieces.append(code[keep_from:])
    return ''.join(pieces)


def extract_type_decl_names(iml_code: str) -> list[str]:
    """Extract the names of types declared in OCaml/IML source.

    Recognizes ``type`` declarations and their ``and``-joined recursive
    siblings, including ``nonrec``, a single type parameter (``'a``), a
    wildcard parameter (``_``, e.g. GADTs) and parenthesized parameter lists
    (``('a, 'b)``).

    This is a regex-based scan, not a parser:

    * comments are removed first;
    * ``and`` only counts while a ``type`` group is open, so
      ``let rec f ... and g ...`` yields nothing;
    * locally abstract types ``(type a)`` and ``module type`` declarations
      are skipped.

    String literals are not recognized, so ``type`` inside a string can
    still produce a spurious name.

    Args:
        iml_code: OCaml/IML source text.

    Returns:
        Declared type names, in order of appearance.

    Examples:
        >>> extract_type_decl_names('type direction = North | South')
        ['direction']
        >>> extract_type_decl_names('type t = A of u and u = B of t')
        ['t', 'u']
    """
    code = _strip_ocaml_comments(iml_code)
    type_names: list[str] = []
    in_type_group = False  # inside a `type ... and ...` declaration group
    pos = 0
    while (kw_match := _DECL_KEYWORD_RE.search(code, pos)) is not None:
        pos = kw_match.end()
        kw = kw_match['kw']
        prefix = (kw_match['prefix'] or '').strip()

        if prefix == '(' and kw in ('type', 'module'):
            # `(type a)` / `(module S)`: neither declares nor ends a type group.
            continue
        if kw == 'type' and prefix == 'module':
            # `module type S = ...` declares a module type, not a type.
            in_type_group = False
            continue
        if kw == 'type':
            in_type_group = True
        elif kw != 'and':
            # let/module/open/...: any open type group is over.
            in_type_group = False
            continue
        elif not in_type_group:
            # `and` continuing a `let rec` (or other non-type) group.
            continue

        name_match = _TYPE_NAME_RE.match(code, pos)
        if name_match is None:
            continue
        names = [name.strip() for name in name_match[1].split(',')]
        if names[0] in _OCAML_KEYWORDS:
            # e.g. `with type t = int and type u = bool`: the real name
            # follows the next `type` keyword.
            continue
        type_names.extend(name for name in names if name not in _OCAML_KEYWORDS)
        # Skip past the consumed name(s) so they are never re-read as keywords.
        pos = name_match.end()

    return type_names
62
+
63
+
64
+ # Main
65
+ # ====================
66
+
67
+
68
def gen_test_cases(
    iml: str,
    decomp_name: str,
    other_decomp_kwargs: dict[str, Any] | None = None,
) -> list[ast_types.stmt]:
    """Generate Python test-case statements for an IML function.

    The steps are:

    1. Evaluate ``iml`` on the ImandraX server (production URL).
    2. Decompose ``decomp_name``.
    3. Fetch declarations for the type names found in ``iml``.
    4. Convert the declarations and the decomposition artifact into Python
       AST statements.

    Args:
        iml: IML source code to evaluate.
        decomp_name: Name of the function to decompose.
        other_decomp_kwargs: Extra keyword arguments forwarded to the
            decompose request.

    Returns:
        Type-definition statements followed by the test-function statements.

    Raises:
        KeyError: If the ``IMANDRAX_API_KEY`` environment variable is unset.
        ValueError: If evaluating ``iml`` fails.
    """
    decomp_kwargs = other_decomp_kwargs or {}

    client = ImandraXClient(
        auth_token=os.environ['IMANDRAX_API_KEY'],
        # url=url_dev,
        url=url_prod,
    )

    # 1. Evaluate the IML source and surface the server's error messages.
    eval_res: EvalRes = client.eval_src(iml)
    if eval_res.success is not True:
        messages = [repr(err.msg) for err in eval_res.errors]
        raise ValueError(f'Failed to evaluate source code: {messages}')

    # 2. Decompose the target function.
    # TODO: it's fixed. We should revert this change back to
    #   decomp_art = client.decompose(decomp_name, **decomp_kwargs).artifact
    # Decoding of fun-decomp artifacts was broken, so this falls back to the
    # plain API client (`Client.decompose`), which has no region extraction.
    raw_decomp = Client.decompose(client, decomp_name, **decomp_kwargs)
    decomp_art = Art.model_validate(cast(PbArt, raw_decomp.artifact))  # type: ignore[reportUnknownMemberType]
    assert decomp_art, 'No artifact returned from decompose'

    # 3. Python type definitions for every type declared in the source.
    type_names: list[str] = extract_type_decl_names(iml)
    decls = client.get_decls(type_names)
    type_def_stmts = [
        stmt
        for decl in decls.decls
        for stmt in ast_of_art(decl.artifact, mode='decl')
    ]

    # 4. Test functions derived from the decomposition.
    test_def_stmts = ast_of_art(decomp_art, mode='fun-decomp')

    return [*type_def_stmts, *test_def_stmts]
+ ]
@@ -0,0 +1,188 @@
1
+ """Convert custom AST to Python stdlib AST and unparse to source code."""
2
+
3
+ import ast as stdlib_ast
4
+ import subprocess
5
+ from typing import Any, Final, cast
6
+
7
+ from ruff.__main__ import find_ruff_bin
8
+
9
+ from . import ast_types as custom_ast
10
+
11
# Path to the ruff executable, located via ruff's own `find_ruff_bin` once at
# import time.
ruff_bin = find_ruff_bin()
12
+
13
+
14
# Source of a minimal `option` type that `unparse` inlines into generated code,
# in place of the `from imandrax_option_lib import option, Some` placeholder,
# when that import survives unused-import removal. It relies on the
# dataclass / TypeVar / Generic / TypeAlias imports that `unparse` also emits.
OPTION_LIB_SRC: Final[str] = """\
T = TypeVar('T')


@dataclass
class Some(Generic[T]):
    value: T

option: TypeAlias = Some[T] | None
"""
24
+
25
+
26
def format_code(code: str) -> str:
    """Sort imports (ruff rule I001), then run ``ruff format`` on ``code``.

    If ``ruff format`` exits with an error or cannot be launched, the
    import-sorted code is returned unformatted.
    """
    import_sorted = fix_ruff_check(code, ['I001'])
    try:
        completed = subprocess.run(
            [ruff_bin, 'format', '-'],
            input=import_sorted,
            capture_output=True,
            text=True,
            check=True,
        )
    except (subprocess.CalledProcessError, FileNotFoundError):
        return import_sorted
    return completed.stdout
44
+
45
+
46
def fix_ruff_check(code: str, rules: list[str]) -> str:
    """Apply ruff's automatic fixes for ``rules`` to ``code``.

    Runs ``ruff check --select <rules> --fix -`` with ``code`` on stdin; in
    that mode ruff writes the fixed source to stdout. The module-level
    ``ruff_bin`` is reused rather than resolving the binary again per call.

    Args:
        code: Python source to fix.
        rules: Ruff rule codes to select (e.g. ``['F401']``). An empty list
            returns ``code`` unchanged without running ruff.

    Returns:
        The fixed source; or ``code`` unchanged if ruff could not be launched,
        terminated abnormally, or printed nothing.
    """
    if not rules:
        return code

    try:
        result = subprocess.run(
            [ruff_bin, 'check', '--select', ','.join(rules), '--fix', '-'],
            check=False,
            input=code,
            capture_output=True,
            text=True,
        )
    except FileNotFoundError:
        # The ruff executable could not be launched.
        return code

    # Exit status 0 or 1 means ruff ran (1 = violations remain after fixing).
    # Any other status is an abnormal termination whose stdout is not the
    # fixed source.
    if result.returncode not in (0, 1) or not result.stdout:
        return code
    return result.stdout
67
+
68
+
69
def remove_unused_import(code: str) -> str:
    """Drop unused imports (ruff rule F401) from ``code``."""
    unused_import_rule = 'F401'
    return fix_ruff_check(code, [unused_import_rule])
72
+
73
+
74
def to_stdlib(node: Any) -> Any:
    """Recursively convert a custom AST node into its stdlib ``ast`` equivalent.

    Nodes are matched to stdlib classes by class name (``custom_ast.Name`` ->
    ``ast.Name``). Every instance attribute is converted and passed as a
    keyword argument. ``None`` and primitive constant values (``str``,
    ``bytes``, ``bool``, ``int``, ``float``, ``complex``) are returned
    unchanged, and lists are converted element-wise.

    Raises:
        ValueError: If no stdlib ``ast`` class has the node's class name.
    """
    if node is None:
        return None

    # Primitive constant payloads. `complex` belongs to Constant's value type
    # (_ConstantValue) and must pass through as well; otherwise it would be
    # looked up as an AST class named "complex" and rejected.
    if isinstance(node, (str, int, float, bool, bytes, complex)):
        return node

    if isinstance(node, list):
        return [to_stdlib(item) for item in cast(list[Any], node)]

    class_name = node.__class__.__name__
    stdlib_class = getattr(stdlib_ast, class_name, None)
    if stdlib_class is None:
        raise ValueError(f'No stdlib AST class found for: {class_name}')

    kwargs: dict[str, Any] = {
        field_name: to_stdlib(field_value)
        for field_name, field_value in node.__dict__.items()
    }
    return stdlib_class(**kwargs)
104
+
105
+
106
def unparse(
    nodes: list[custom_ast.stmt],
    alias_real_to_float: bool = False,
    # TODO: add a config field?
    # - [x] whether to alias `real` to `float` or not
    # - [ ] alternatively: use Decimal instead of float
    # - the python version to use: 3.12+ or not
    #   - this determines the type definition syntax
    #   # 3.12+
    #   class Pair[A, B]:
    #       first: A
    #       second: B
    #   # ----
    #   # 3.11-
    #   A = TypeVar('A')
    #   B = TypeVar('B')
    #   class Pair(Generic[A, B]):
    #       first: A
    #       second: B
    #   TODO: use ruff upgrade and --target-version parameter
) -> str:
    """Render custom AST statements as formatted Python source.

    The statements are converted to the stdlib AST and given a fixed header:

    * ``from __future__ import annotations``;
    * the ``dataclass`` import;
    * the ``TypeVar``/``Generic``/``TypeAlias`` imports;
    * an ``imandrax_option_lib`` import placeholder;
    * ``real = float``, when ``alias_real_to_float`` is true.

    Unused imports are then removed with ruff. If the placeholder import
    survives, the module is regenerated with the placeholder replaced by the
    inline definition in ``OPTION_LIB_SRC``. The result is import-sorted and
    formatted by ``format_code``.

    Args:
        nodes: Statements to render.
        alias_real_to_float: Emit ``real = float`` before the statements.

    Returns:
        The generated Python source code.
    """
    converted: list[stdlib_ast.stmt] = to_stdlib(nodes)

    def parse_stmts(src: str) -> list[stdlib_ast.stmt]:
        return stdlib_ast.parse(src).body

    def render(stmts: list[stdlib_ast.stmt]) -> str:
        """Unparse ``stmts`` as one module, filling in missing locations first."""
        module = stdlib_ast.Module(body=stmts, type_ignores=[])
        stdlib_ast.fix_missing_locations(module)
        return stdlib_ast.unparse(module)

    def import_from(module: str, *names: str) -> stdlib_ast.ImportFrom:
        return stdlib_ast.ImportFrom(
            module=module,
            names=[stdlib_ast.alias(name=name, asname=None) for name in names],
            level=0,
        )

    # Header shared by both passes.
    common_imports: list[stdlib_ast.stmt] = [
        import_from('__future__', 'annotations'),
        import_from('dataclasses', 'dataclass'),
        import_from('typing', 'TypeVar', 'Generic', 'TypeAlias'),
    ]
    real_alias = parse_stmts('real = float') if alias_real_to_float else []

    def assemble(option_stmts: list[stdlib_ast.stmt]) -> str:
        return render([*common_imports, *option_stmts, *real_alias, *converted])

    # First pass: reference the option library through an import, so that
    # ruff's unused-import removal reveals whether the code uses it.
    code = remove_unused_import(
        assemble(parse_stmts('from imandrax_option_lib import option, Some'))
    )

    # Second pass: the import survived, so `option`/`Some` are used; inline
    # their definition instead of depending on the external module.
    if 'from imandrax_option_lib' in code:
        code = assemble(parse_stmts(OPTION_LIB_SRC))

    return format_code(code)
@@ -0,0 +1,21 @@
1
+ Metadata-Version: 2.4
2
+ Name: imandrax-codegen
3
+ Version: 18.1.1
4
+ Summary: Code generator for ImandraX artifact
5
+ Author-email: hongyu <hongyu@imandra.ai>
6
+ Requires-Python: >=3.12
7
+ Description-Content-Type: text/markdown
8
+ Requires-Dist: devtools>=0.12.2
9
+ Requires-Dist: dotenv>=0.9.9
10
+ Requires-Dist: imandrax-api-models>=18.0.0
11
+ Requires-Dist: imandrax-api[async]>=0.18.0.1
12
+ Requires-Dist: iml-query>=0.3.4
13
+ Requires-Dist: pydantic>=2.12.3
14
+ Requires-Dist: pyyaml>=6.0.3
15
+ Requires-Dist: rich>=14.2.0
16
+ Requires-Dist: ruff>=0.14.3
17
+ Requires-Dist: typer>=0.21.0
18
+
19
+ # ImandraX Code Generator
20
+
21
+ Code generator for ImandraX artifact
@@ -0,0 +1,15 @@
1
+ imandrax_codegen/__main__.py,sha256=8MRtWTKcBa6ip1mCm0suZSrawTHcGgXj5M7WrZ6gTOo,1096
2
+ imandrax_codegen/art_parse.exe,sha256=C-KBJWPejMYFseJel-h6LQoSjhZlMkL-k9Us6BMJk-Q,11236424
3
+ imandrax_codegen/art_parse.py,sha256=tDj7_FJjpl1uUMx056Y5WAUVHM3hoBD-0uW60GPD5PY,3244
4
+ imandrax_codegen/ast_deserialize.py,sha256=n_Mjpb3qMzXKFjHt4s4HvTOmfDt9trcuWFkvN8iak6c,3329
5
+ imandrax_codegen/ast_types.py,sha256=xMg6qKQtLO7-iOEP5sDlUFWQrw0xre8y6QWpSZEynxE,5320
6
+ imandrax_codegen/output.py,sha256=BSr9xxYau9acHLDYaIt2eJcoQ5sGbzmtcs1859srvpA,132
7
+ imandrax_codegen/test_gen.py,sha256=pyx6Qk0zFjzJuYipyk9gdnNcbWz00N8vZ7Qdoy6FkNM,3799
8
+ imandrax_codegen/unparse.py,sha256=ejfYx-ttVIhN3aG9vHlzRDy8q3jHmCsjDWNYq6xxdPI,5469
9
+ imandrax_codegen/code_of_ast/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
10
+ imandrax_codegen/code_of_ast/__main__.py,sha256=fAsdippqbtGvMUqOChRc--EZbfFkKSgdNjdbWdOXzTo,1634
11
+ imandrax_codegen-18.1.1.dist-info/METADATA,sha256=7L1duYkrRYkY2jZ0Nu0dowWuXvrhYtsz9SDKNqH62aM,604
12
+ imandrax_codegen-18.1.1.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
13
+ imandrax_codegen-18.1.1.dist-info/entry_points.txt,sha256=4dag2GtGgO2COW5g6OvQhzDQaWe79ysyQgcJLNOBwOw,67
14
+ imandrax_codegen-18.1.1.dist-info/top_level.txt,sha256=DSAK5XzSRhi4WC4SloQIsDPNK9_B6bKVmnxukfyCozc,17
15
+ imandrax_codegen-18.1.1.dist-info/RECORD,,
@@ -0,0 +1,5 @@
1
+ Wheel-Version: 1.0
2
+ Generator: setuptools (80.9.0)
3
+ Root-Is-Purelib: true
4
+ Tag: py3-none-any
5
+
@@ -0,0 +1,2 @@
1
+ [console_scripts]
2
+ imandrax-codegen = imandrax_codegen.__main__:app
@@ -0,0 +1 @@
1
+ imandrax_codegen