jaclang 0.8.7__py3-none-any.whl → 0.8.9__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of jaclang might be problematic; see the registry listing for details.

Files changed (99):
  1. jaclang/cli/cli.py +77 -29
  2. jaclang/cli/cmdreg.py +44 -0
  3. jaclang/compiler/constant.py +6 -2
  4. jaclang/compiler/jac.lark +37 -47
  5. jaclang/compiler/larkparse/jac_parser.py +2 -2
  6. jaclang/compiler/parser.py +356 -61
  7. jaclang/compiler/passes/main/__init__.py +2 -4
  8. jaclang/compiler/passes/main/def_use_pass.py +1 -4
  9. jaclang/compiler/passes/main/predynamo_pass.py +221 -0
  10. jaclang/compiler/passes/main/pyast_gen_pass.py +221 -135
  11. jaclang/compiler/passes/main/pyast_load_pass.py +54 -20
  12. jaclang/compiler/passes/main/sym_tab_build_pass.py +1 -1
  13. jaclang/compiler/passes/main/tests/fixtures/checker/import_sym.jac +2 -0
  14. jaclang/compiler/passes/main/tests/fixtures/checker/import_sym_test.jac +6 -0
  15. jaclang/compiler/passes/main/tests/fixtures/checker/imported_sym.jac +5 -0
  16. jaclang/compiler/passes/main/tests/fixtures/checker_arg_param_match.jac +37 -0
  17. jaclang/compiler/passes/main/tests/fixtures/checker_arity.jac +18 -0
  18. jaclang/compiler/passes/main/tests/fixtures/checker_cat_is_animal.jac +18 -0
  19. jaclang/compiler/passes/main/tests/fixtures/checker_float.jac +7 -0
  20. jaclang/compiler/passes/main/tests/fixtures/checker_param_types.jac +11 -0
  21. jaclang/compiler/passes/main/tests/fixtures/checker_self_type.jac +9 -0
  22. jaclang/compiler/passes/main/tests/fixtures/checker_sym_inherit.jac +42 -0
  23. jaclang/compiler/passes/main/tests/fixtures/predynamo_fix3.jac +43 -0
  24. jaclang/compiler/passes/main/tests/fixtures/predynamo_where_assign.jac +13 -0
  25. jaclang/compiler/passes/main/tests/fixtures/predynamo_where_return.jac +11 -0
  26. jaclang/compiler/passes/main/tests/test_checker_pass.py +190 -0
  27. jaclang/compiler/passes/main/tests/test_predynamo_pass.py +56 -0
  28. jaclang/compiler/passes/main/type_checker_pass.py +29 -73
  29. jaclang/compiler/passes/tool/doc_ir_gen_pass.py +302 -58
  30. jaclang/compiler/passes/tool/jac_formatter_pass.py +119 -69
  31. jaclang/compiler/passes/tool/tests/fixtures/corelib_fmt.jac +3 -3
  32. jaclang/compiler/passes/tool/tests/fixtures/general_format_checks/triple_quoted_string.jac +4 -5
  33. jaclang/compiler/passes/tool/tests/fixtures/import_fmt.jac +7 -1
  34. jaclang/compiler/passes/tool/tests/fixtures/tagbreak.jac +276 -10
  35. jaclang/compiler/passes/transform.py +12 -8
  36. jaclang/compiler/program.py +19 -7
  37. jaclang/compiler/tests/fixtures/jac_import_py_files.py +4 -0
  38. jaclang/compiler/tests/fixtures/jac_module.jac +3 -0
  39. jaclang/compiler/tests/fixtures/multiple_syntax_errors.jac +10 -0
  40. jaclang/compiler/tests/fixtures/python_module.py +1 -0
  41. jaclang/compiler/tests/test_importer.py +39 -0
  42. jaclang/compiler/tests/test_parser.py +49 -0
  43. jaclang/compiler/type_system/type_evaluator.jac +959 -0
  44. jaclang/compiler/type_system/type_utils.py +246 -0
  45. jaclang/compiler/type_system/types.py +58 -2
  46. jaclang/compiler/unitree.py +102 -107
  47. jaclang/langserve/engine.jac +138 -159
  48. jaclang/langserve/server.jac +25 -1
  49. jaclang/langserve/tests/fixtures/circle.jac +3 -3
  50. jaclang/langserve/tests/fixtures/circle_err.jac +3 -3
  51. jaclang/langserve/tests/fixtures/circle_pure.test.jac +3 -3
  52. jaclang/langserve/tests/fixtures/completion_test_err.jac +10 -0
  53. jaclang/langserve/tests/server_test/circle_template.jac +80 -0
  54. jaclang/langserve/tests/server_test/glob_template.jac +4 -0
  55. jaclang/langserve/tests/server_test/test_lang_serve.py +154 -309
  56. jaclang/langserve/tests/server_test/utils.py +153 -116
  57. jaclang/langserve/tests/test_server.py +21 -84
  58. jaclang/langserve/utils.jac +12 -15
  59. jaclang/lib.py +17 -0
  60. jaclang/runtimelib/archetype.py +25 -25
  61. jaclang/runtimelib/constructs.py +2 -2
  62. jaclang/runtimelib/machine.py +63 -46
  63. jaclang/runtimelib/meta_importer.py +27 -1
  64. jaclang/runtimelib/tests/fixtures/custom_access_validation.jac +1 -1
  65. jaclang/runtimelib/tests/fixtures/savable_object.jac +2 -2
  66. jaclang/settings.py +19 -16
  67. jaclang/tests/fixtures/abc_check.jac +3 -3
  68. jaclang/tests/fixtures/arch_rel_import_creation.jac +12 -12
  69. jaclang/tests/fixtures/attr_pattern_case.jac +18 -0
  70. jaclang/tests/fixtures/chandra_bugs2.jac +3 -3
  71. jaclang/tests/fixtures/create_dynamic_archetype.jac +13 -13
  72. jaclang/tests/fixtures/funccall_genexpr.jac +7 -0
  73. jaclang/tests/fixtures/funccall_genexpr.py +5 -0
  74. jaclang/tests/fixtures/maxfail_run_test.jac +4 -4
  75. jaclang/tests/fixtures/params/param_syntax_err.jac +9 -0
  76. jaclang/tests/fixtures/params/test_complex_params.jac +42 -0
  77. jaclang/tests/fixtures/params/test_failing_kwonly.jac +207 -0
  78. jaclang/tests/fixtures/params/test_failing_posonly.jac +116 -0
  79. jaclang/tests/fixtures/params/test_failing_varargs.jac +300 -0
  80. jaclang/tests/fixtures/params/test_kwonly_params.jac +29 -0
  81. jaclang/tests/fixtures/py2jac_params.py +8 -0
  82. jaclang/tests/fixtures/run_test.jac +4 -4
  83. jaclang/tests/test_cli.py +159 -7
  84. jaclang/tests/test_language.py +213 -38
  85. jaclang/tests/test_reference.py +3 -1
  86. jaclang/utils/helpers.py +67 -6
  87. jaclang/utils/module_resolver.py +10 -0
  88. jaclang/utils/test.py +8 -0
  89. jaclang/utils/tests/test_lang_tools.py +4 -15
  90. jaclang/utils/treeprinter.py +0 -18
  91. {jaclang-0.8.7.dist-info → jaclang-0.8.9.dist-info}/METADATA +1 -2
  92. {jaclang-0.8.7.dist-info → jaclang-0.8.9.dist-info}/RECORD +95 -65
  93. {jaclang-0.8.7.dist-info → jaclang-0.8.9.dist-info}/WHEEL +1 -1
  94. jaclang/compiler/passes/main/inheritance_pass.py +0 -131
  95. jaclang/compiler/type_system/type_evaluator.py +0 -560
  96. jaclang/langserve/dev_engine.jac +0 -645
  97. jaclang/langserve/dev_server.jac +0 -201
  98. jaclang/{langserve/tests/server_test/code_test.py → tests/fixtures/py2jac_empty.py} +0 -0
  99. {jaclang-0.8.7.dist-info → jaclang-0.8.9.dist-info}/entry_points.txt +0 -0
@@ -0,0 +1,221 @@
1
+ """Pytorch Fix Pass."""
2
+
3
+ import ast as ast3
4
+ from typing import Optional, TypeVar, cast
5
+
6
+ import jaclang.compiler.unitree as uni
7
+ from jaclang.compiler.constant import Tokens as Tok
8
+ from jaclang.compiler.passes import UniPass
9
+
10
+
11
+ T = TypeVar("T", bound=ast3.AST)  # NOTE(review): not referenced anywhere in this module
12
+
13
+
14
class PreDynamoPass(UniPass):
    """Pre-Dynamo pass for PyTorch.

    Rewrites two-branch ``if``/``else`` statements whose branches differ only
    in a value into a branch-free form built around
    ``torch.where(condition, then_value, else_value)``.  Presumably this lets
    TorchDynamo trace the code without a data-dependent graph break -- TODO
    confirm against the pass's callers.

    Each branch must be exactly one statement and the alternative must be a
    plain ``else`` (not an ``elif``).  Three shapes are recognised:

    * ``if c {x = a;} else {x = b;}``  ->  ``x = torch.where(c, a, b);``
    * ``if c {return a;} else {return b;}``  ->  ``return torch.where(c, a, b);``
    * ``if c {o.m("n", a, k=v);} else {o.m("n", b, k=v);}``  ->
      ``__n = torch.where(c, a, b); o.m("n", __n, k=v);``

    NOTE(review): ``torch.where`` evaluates both values eagerly, so side
    effects in either branch now always run.  The emitted code also refers to
    a bare ``torch`` name, which this pass does not import.
    """

    def enter_node(self, node: uni.UniNode) -> None:
        """Enter node (delegates to ``UniPass``)."""
        super().enter_node(node)

    def exit_node(self, node: uni.UniNode) -> None:
        """Exit node (delegates to ``UniPass``)."""
        super().exit_node(node)

    def gen_name(self, node: uni.UniNode, name: Tok, value: str) -> uni.Name:
        """Generate a synthetic ``Name`` token located at ``node``.

        The source file, line span and start column are copied from ``node``.
        The end column and character offsets are set to 0 because the token
        is synthesized rather than parsed.
        """
        return uni.Name(
            name=name,
            value=value,
            orig_src=node.loc.orig_src,
            col_start=node.loc.col_start,
            col_end=0,
            line=node.loc.first_line,
            end_line=node.loc.last_line,
            pos_start=0,
            pos_end=0,
        )

    def replace_node(
        self,
        new_nodes: list[uni.UniNode] | uni.UniNode,
        old_node: uni.UniNode,
        attr: str,
    ) -> None:
        """Replace ``old_node`` with ``new_nodes`` in its parent's lists.

        ``new_nodes`` is either a single node, swapped in place, or a list of
        nodes spliced in at ``old_node``'s position.  Both the parent's
        ``attr`` list (e.g. ``"body"``) and its ``kid`` list are updated.  A
        list that does not contain ``old_node`` is left untouched.
        """
        parent = old_node.parent
        if isinstance(new_nodes, uni.UniNode):
            new_nodes.parent = parent
            if hasattr(parent, attr):
                lst = getattr(parent, attr)
                if old_node in lst:
                    lst[lst.index(old_node)] = new_nodes
            if hasattr(parent, "kid") and old_node in parent.kid:
                parent.kid[parent.kid.index(old_node)] = new_nodes
            return

        for n in new_nodes:
            n.parent = parent
        if hasattr(parent, attr):
            lst = getattr(parent, attr)
            if old_node in lst:
                idx = lst.index(old_node)
                setattr(parent, attr, lst[:idx] + new_nodes + lst[idx + 1 :])
        if hasattr(parent, "kid") and old_node in parent.kid:
            idx = parent.kid.index(old_node)
            parent.kid = parent.kid[:idx] + new_nodes + parent.kid[idx + 1 :]

    def check_same_lhs(
        self, assign_a: uni.UniNode, assign_b: uni.UniNode
    ) -> Optional[uni.Name]:
        """Return the common LHS if both are simple assignments to the same name.

        Only plain single-target ``name = value`` assignments qualify.
        Augmented (``x += v``), chained/multi-target and value-less
        assignments are rejected: rewriting them as ``x = torch.where(...)``
        would change their meaning.
        """
        if not (
            isinstance(assign_a, uni.Assignment)
            and isinstance(assign_b, uni.Assignment)
        ):
            return None
        for assign in (assign_a, assign_b):
            if (
                len(assign.target) != 1
                or assign.value is None
                or getattr(assign, "aug_op", None) is not None
            ):
                return None
        ta, tb = assign_a.target[0], assign_b.target[0]
        if not (isinstance(ta, uni.Name) and isinstance(tb, uni.Name)):
            return None
        if ta.value != tb.value:
            return None
        return ta  # common target

    def check_call(self, node: uni.ExprStmt) -> Optional[tuple]:
        """Match ``target("name", tensor_expr, k=v, ...)`` as a bare statement.

        Returns ``(target, name, tensor_expr, kwargs)``, where ``name`` is the
        ``uni.String`` literal and ``kwargs`` maps keyword names to value
        nodes.  Returns None for anything the rewrite could not reproduce
        faithfully:

        * the target is not an attribute/trailer;
        * the first argument is not a single string literal;
        * there are extra positional arguments;
        * there is a ``**mapping`` argument.
        """
        if not (isinstance(node, uni.ExprStmt) and isinstance(node.expr, uni.FuncCall)):
            return None
        call = node.expr
        params = list(call.params)
        if not (
            isinstance(call.target, uni.AtomTrailer)
            and len(params) >= 2
            and isinstance(params[1], uni.Expr)
        ):
            return None

        name = params[0]
        if isinstance(name, uni.MultiString):
            # Only a one-part literal is accepted; keeping just strings[0] of
            # an implicitly concatenated literal would shorten the name.
            if len(name.strings) != 1:
                return None
            name = name.strings[0]
        if not isinstance(name, uni.String):
            return None

        kwargs: dict = {}
        for extra in params[2:]:
            # Extra positional args and ``**mapping`` entries (KWPair with no
            # key) cannot be carried over to the regenerated call.
            if not isinstance(extra, uni.KWPair) or extra.key is None:
                return None
            kwargs[extra.key._sym_name] = extra.value
        return (call.target, name, params[1], kwargs)

    def _string_value(self, node: uni.String) -> Optional[str]:
        """Decode a string-literal node to its ``str`` value, or None if not possible."""
        # literal_eval only accepts literals, so no code in the source is executed.
        try:
            value = ast3.literal_eval(node.value)
        except (ValueError, SyntaxError, TypeError, MemoryError, RecursionError):
            return None
        return value if isinstance(value, str) else None

    def _node_signature(self, node: uni.UniNode) -> tuple:
        """Return a structural fingerprint of ``node`` for equality checks.

        A token contributes its class name and text.  An interior node
        contributes its class name and its children's fingerprints.  A
        childless interior node falls back to object identity, so two
        unrelated synthetic nodes never compare equal.
        """
        if isinstance(node, uni.Token):
            return (type(node).__name__, node.value)
        if not node.kid:
            return (type(node).__name__, id(node))
        return (
            type(node).__name__,
            tuple(self._node_signature(k) for k in node.kid),
        )

    def _gen_torch_where(
        self, node: uni.IfStmt, args: list[uni.Expr]
    ) -> uni.FuncCall:
        """Build a ``torch.where(*args)`` call node located at ``node``."""
        func_name = self.gen_name(node, Tok.NAME, "torch")
        attr_name = self.gen_name(node, Tok.NAME, "where")
        target = uni.AtomTrailer(
            target=func_name,
            right=attr_name,
            is_attr=True,
            is_null_ok=False,
            kid=[func_name, attr_name],
        )
        # kid mirrors params so tree traversal visits exactly the arguments.
        return uni.FuncCall(
            target=target,
            params=list(args),
            genai_call=None,
            kid=[target, *args],
        )

    def _rewrite_assignment(
        self, node: uni.IfStmt, a0: uni.Assignment, b0: uni.Assignment
    ) -> None:
        """Replace ``if c {x = a;} else {x = b;}`` with ``x = torch.where(c, a, b);``."""
        lhs = self.check_same_lhs(a0, b0)
        if lhs is None:
            return
        call = self._gen_torch_where(
            node,
            [node.condition, cast(uni.Expr, a0.value), cast(uni.Expr, b0.value)],
        )
        new_node = uni.Assignment(
            target=[lhs], value=call, type_tag=None, kid=[lhs, call]
        )
        self.replace_node(new_node, node, "body")

    def _rewrite_return(
        self, node: uni.IfStmt, a0: uni.ReturnStmt, b0: uni.ReturnStmt
    ) -> None:
        """Replace ``if c {return a;} else {return b;}`` with ``return torch.where(c, a, b);``."""
        aexpr, bexpr = a0.expr, b0.expr
        if aexpr is None or bexpr is None:
            return
        call = self._gen_torch_where(
            node, [node.condition, cast(uni.Expr, aexpr), cast(uni.Expr, bexpr)]
        )
        new_node = uni.ReturnStmt(expr=call, kid=[call])
        self.replace_node(new_node, node, "body")

    def _rewrite_call(
        self, node: uni.IfStmt, a0: uni.ExprStmt, b0: uni.ExprStmt
    ) -> None:
        """Hoist a differing tensor argument of two identical calls into a temporary.

        ``if c {o.m("n", a, k=v);} else {o.m("n", b, k=v);}`` becomes
        ``__n = torch.where(c, a, b); o.m("n", __n, k=v);``.  The callee,
        name literal and keyword arguments must match in both branches,
        otherwise the else-branch's call would be lost.
        """
        a_reg = self.check_call(a0)
        b_reg = self.check_call(b0)
        if a_reg is None or b_reg is None:
            return
        a_target, a_name, a_expr, a_kwargs = a_reg
        b_target, b_name, b_expr, b_kwargs = b_reg

        key = self._string_value(a_name)
        if key is None or key != self._string_value(b_name):
            return
        if a_kwargs.keys() != b_kwargs.keys():
            return
        if self._node_signature(a_target) != self._node_signature(b_target):
            return
        if any(
            self._node_signature(a_kwargs[k]) != self._node_signature(b_kwargs[k])
            for k in a_kwargs
        ):
            return
        tmp_ident = f"__{key}"
        if not tmp_ident.isidentifier():
            return

        tmp_name = self.gen_name(node, Tok.NAME, tmp_ident)
        tmp_name.py_ctx_func = ast3.Store
        call = self._gen_torch_where(node, [node.condition, a_expr, b_expr])
        assign_node = uni.Assignment(
            target=[tmp_name],
            value=call,
            type_tag=None,
            kid=[tmp_name, call],
        )

        kwargs_nodes = []
        for k, v in a_kwargs.items():
            key_name = self.gen_name(node, Tok.NAME, k)
            kwargs_nodes.append(uni.KWPair(key_name, v, [key_name, v]))
        param_name = self.gen_name(node, Tok.NAME, tmp_ident)
        reg_call = uni.FuncCall(
            target=a_target,
            params=[a_name, param_name, *kwargs_nodes],
            genai_call=None,
            kid=[a_target, a_name, param_name, *kwargs_nodes],
        )
        reg_node = uni.ExprStmt(expr=reg_call, in_fstring=False, kid=[reg_call])
        self.replace_node([assign_node, reg_node], node, "body")

    def exit_if_stmt(self, node: uni.IfStmt) -> None:
        """Exit if statement: try to rewrite it into a ``torch.where`` form."""
        else_body = node.else_body
        # Only a plain ``else`` qualifies.  An ``elif`` carries its own
        # condition, which a two-way torch.where cannot express.
        if not isinstance(else_body, uni.ElseStmt):
            return
        # Each branch must be exactly one statement; replacing the whole ``if``
        # when a branch has more would silently drop code.
        if len(node.body) != 1 or len(else_body.body) != 1:
            return
        a0, b0 = node.body[0], else_body.body[0]

        if isinstance(a0, uni.Assignment) and isinstance(b0, uni.Assignment):
            self._rewrite_assignment(node, a0, b0)
        elif isinstance(a0, uni.ReturnStmt) and isinstance(b0, uni.ReturnStmt):
            self._rewrite_return(node, a0, b0)
        elif isinstance(a0, uni.ExprStmt) and isinstance(b0, uni.ExprStmt):
            self._rewrite_call(node, a0, b0)