jaclang 0.8.7__py3-none-any.whl → 0.8.8__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of jaclang might be problematic. Click here for more details.
- jaclang/cli/cli.py +13 -27
- jaclang/cli/cmdreg.py +44 -0
- jaclang/compiler/constant.py +0 -1
- jaclang/compiler/jac.lark +3 -6
- jaclang/compiler/larkparse/jac_parser.py +2 -2
- jaclang/compiler/parser.py +213 -34
- jaclang/compiler/passes/main/__init__.py +2 -4
- jaclang/compiler/passes/main/def_use_pass.py +0 -4
- jaclang/compiler/passes/main/predynamo_pass.py +221 -0
- jaclang/compiler/passes/main/pyast_gen_pass.py +70 -52
- jaclang/compiler/passes/main/pyast_load_pass.py +52 -20
- jaclang/compiler/passes/main/sym_tab_build_pass.py +1 -1
- jaclang/compiler/passes/main/tests/fixtures/checker/import_sym.jac +2 -0
- jaclang/compiler/passes/main/tests/fixtures/checker/import_sym_test.jac +6 -0
- jaclang/compiler/passes/main/tests/fixtures/checker/imported_sym.jac +5 -0
- jaclang/compiler/passes/main/tests/fixtures/checker_arg_param_match.jac +37 -0
- jaclang/compiler/passes/main/tests/fixtures/checker_arity.jac +18 -0
- jaclang/compiler/passes/main/tests/fixtures/checker_cat_is_animal.jac +18 -0
- jaclang/compiler/passes/main/tests/fixtures/checker_float.jac +7 -0
- jaclang/compiler/passes/main/tests/fixtures/checker_param_types.jac +11 -0
- jaclang/compiler/passes/main/tests/fixtures/checker_self_type.jac +9 -0
- jaclang/compiler/passes/main/tests/fixtures/checker_sym_inherit.jac +42 -0
- jaclang/compiler/passes/main/tests/fixtures/predynamo_fix3.jac +43 -0
- jaclang/compiler/passes/main/tests/fixtures/predynamo_where_assign.jac +13 -0
- jaclang/compiler/passes/main/tests/fixtures/predynamo_where_return.jac +11 -0
- jaclang/compiler/passes/main/tests/test_checker_pass.py +191 -0
- jaclang/compiler/passes/main/tests/test_predynamo_pass.py +57 -0
- jaclang/compiler/passes/main/type_checker_pass.py +29 -73
- jaclang/compiler/passes/tool/doc_ir_gen_pass.py +204 -44
- jaclang/compiler/passes/tool/jac_formatter_pass.py +119 -69
- jaclang/compiler/passes/tool/tests/fixtures/corelib_fmt.jac +3 -3
- jaclang/compiler/passes/tool/tests/fixtures/general_format_checks/triple_quoted_string.jac +4 -5
- jaclang/compiler/passes/tool/tests/fixtures/tagbreak.jac +171 -11
- jaclang/compiler/passes/transform.py +12 -8
- jaclang/compiler/program.py +14 -6
- jaclang/compiler/tests/fixtures/jac_import_py_files.py +4 -0
- jaclang/compiler/tests/fixtures/jac_module.jac +3 -0
- jaclang/compiler/tests/fixtures/multiple_syntax_errors.jac +10 -0
- jaclang/compiler/tests/fixtures/python_module.py +1 -0
- jaclang/compiler/tests/test_importer.py +39 -0
- jaclang/compiler/tests/test_parser.py +49 -0
- jaclang/compiler/type_system/type_evaluator.py +351 -67
- jaclang/compiler/type_system/type_utils.py +246 -0
- jaclang/compiler/type_system/types.py +58 -2
- jaclang/compiler/unitree.py +79 -94
- jaclang/langserve/engine.jac +138 -159
- jaclang/langserve/server.jac +25 -1
- jaclang/langserve/tests/fixtures/circle.jac +3 -3
- jaclang/langserve/tests/fixtures/circle_err.jac +3 -3
- jaclang/langserve/tests/fixtures/circle_pure.test.jac +3 -3
- jaclang/langserve/tests/fixtures/completion_test_err.jac +10 -0
- jaclang/langserve/tests/server_test/circle_template.jac +80 -0
- jaclang/langserve/tests/server_test/glob_template.jac +4 -0
- jaclang/langserve/tests/server_test/test_lang_serve.py +154 -309
- jaclang/langserve/tests/server_test/utils.py +153 -116
- jaclang/langserve/tests/test_server.py +21 -84
- jaclang/langserve/utils.jac +12 -15
- jaclang/runtimelib/machine.py +7 -0
- jaclang/runtimelib/meta_importer.py +27 -1
- jaclang/runtimelib/tests/fixtures/custom_access_validation.jac +1 -1
- jaclang/runtimelib/tests/fixtures/savable_object.jac +2 -2
- jaclang/settings.py +18 -14
- jaclang/tests/fixtures/abc_check.jac +3 -3
- jaclang/tests/fixtures/arch_rel_import_creation.jac +12 -12
- jaclang/tests/fixtures/chandra_bugs2.jac +3 -3
- jaclang/tests/fixtures/create_dynamic_archetype.jac +13 -13
- jaclang/tests/fixtures/maxfail_run_test.jac +4 -4
- jaclang/tests/fixtures/params/param_syntax_err.jac +9 -0
- jaclang/tests/fixtures/params/test_complex_params.jac +42 -0
- jaclang/tests/fixtures/params/test_failing_kwonly.jac +207 -0
- jaclang/tests/fixtures/params/test_failing_posonly.jac +116 -0
- jaclang/tests/fixtures/params/test_failing_varargs.jac +300 -0
- jaclang/tests/fixtures/params/test_kwonly_params.jac +29 -0
- jaclang/tests/fixtures/py2jac_params.py +8 -0
- jaclang/tests/fixtures/run_test.jac +4 -4
- jaclang/tests/test_cli.py +37 -1
- jaclang/tests/test_language.py +74 -16
- jaclang/utils/helpers.py +47 -2
- jaclang/utils/module_resolver.py +10 -0
- jaclang/utils/test.py +8 -0
- jaclang/utils/treeprinter.py +0 -18
- {jaclang-0.8.7.dist-info → jaclang-0.8.8.dist-info}/METADATA +1 -2
- {jaclang-0.8.7.dist-info → jaclang-0.8.8.dist-info}/RECORD +85 -60
- {jaclang-0.8.7.dist-info → jaclang-0.8.8.dist-info}/WHEEL +1 -1
- jaclang/compiler/passes/main/inheritance_pass.py +0 -131
- jaclang/langserve/dev_engine.jac +0 -645
- jaclang/langserve/dev_server.jac +0 -201
- jaclang/langserve/tests/server_test/code_test.py +0 -0
- {jaclang-0.8.7.dist-info → jaclang-0.8.8.dist-info}/entry_points.txt +0 -0
jaclang/compiler/parser.py
CHANGED
|
@@ -5,12 +5,14 @@ from __future__ import annotations
|
|
|
5
5
|
import keyword
|
|
6
6
|
import logging
|
|
7
7
|
import os
|
|
8
|
+
import sys
|
|
8
9
|
from typing import Callable, Sequence, TYPE_CHECKING, TypeAlias, TypeVar, cast
|
|
9
10
|
|
|
10
11
|
import jaclang.compiler.unitree as uni
|
|
11
|
-
from jaclang.compiler import jac_lark as jl
|
|
12
|
+
from jaclang.compiler import TOKEN_MAP, jac_lark as jl
|
|
12
13
|
from jaclang.compiler.constant import EdgeDir, Tokens as Tok
|
|
13
14
|
from jaclang.compiler.passes.main import Transform
|
|
15
|
+
from jaclang.utils.helpers import ANSIColors
|
|
14
16
|
from jaclang.vendor.lark import Lark, Transformer, Tree, logger
|
|
15
17
|
|
|
16
18
|
if TYPE_CHECKING:
|
|
@@ -39,22 +41,17 @@ class JacParser(Transform[uni.Source, uni.Module]):
|
|
|
39
41
|
tree, comments = JacParser.parse(ir_in.value, on_error=self.error_callback)
|
|
40
42
|
mod = JacParser.TreeToAST(parser=self).transform(tree)
|
|
41
43
|
ir_in.comments = [self.proc_comment(i, mod) for i in comments]
|
|
42
|
-
if isinstance(mod, uni.Module):
|
|
43
|
-
self.ir_out = mod
|
|
44
|
-
return mod
|
|
45
|
-
else:
|
|
44
|
+
if not isinstance(mod, uni.Module):
|
|
46
45
|
raise self.ice()
|
|
46
|
+
if len(self.errors_had) != 0:
|
|
47
|
+
mod.has_syntax_errors = True
|
|
48
|
+
self.report_errors()
|
|
49
|
+
self.ir_out = mod
|
|
50
|
+
return mod
|
|
47
51
|
except jl.UnexpectedInput as e:
|
|
48
|
-
catch_error =
|
|
49
|
-
|
|
50
|
-
catch_error.line_no = e.line
|
|
51
|
-
catch_error.end_line = e.line
|
|
52
|
-
catch_error.c_start = e.column
|
|
53
|
-
catch_error.c_end = e.column + 1
|
|
54
|
-
catch_error.pos_start = e.pos_in_stream or 0
|
|
55
|
-
catch_error.pos_end = catch_error.pos_start + 1
|
|
52
|
+
catch_error = self.error_to_token(e)
|
|
53
|
+
error_msg = self.error_to_message(e)
|
|
56
54
|
|
|
57
|
-
error_msg = "Syntax Error"
|
|
58
55
|
if len(e.args) >= 1 and isinstance(e.args[0], str):
|
|
59
56
|
error_msg += e.args[0]
|
|
60
57
|
self.log_error(error_msg, node_override=catch_error)
|
|
@@ -62,7 +59,12 @@ class JacParser(Transform[uni.Source, uni.Module]):
|
|
|
62
59
|
except Exception as e:
|
|
63
60
|
raise e
|
|
64
61
|
|
|
65
|
-
|
|
62
|
+
# If we reach here, there was a syntax error, mark the module as such
|
|
63
|
+
# and report errors.
|
|
64
|
+
self.report_errors()
|
|
65
|
+
mod = uni.Module.make_stub(inject_src=ir_in)
|
|
66
|
+
mod.has_syntax_errors = True
|
|
67
|
+
return mod
|
|
66
68
|
|
|
67
69
|
@staticmethod
|
|
68
70
|
def proc_comment(token: jl.Token, mod: uni.UniNode) -> uni.CommentToken:
|
|
@@ -80,10 +82,109 @@ class JacParser(Transform[uni.Source, uni.Module]):
|
|
|
80
82
|
kid=[],
|
|
81
83
|
)
|
|
82
84
|
|
|
85
|
+
_MISSING_TOKENS = [
    # Tokens error recovery may synthesize, tried in this order: the first one
    # the parser currently accepts is the one inserted.
    Tok.SEMI,
    Tok.COMMA,
    Tok.COLON,
    Tok.RPAREN,
    Tok.RBRACE,
    Tok.RSQUARE,
    Tok.RETURN_HINT,
]

def error_callback(self, e: jl.UnexpectedInput) -> bool:
    """Try to recover from a syntax error reported by lark (``on_error`` hook).

    Only ``UnexpectedToken`` errors are handled:

    1. A dangling ``.`` where a NAME is expected gets a placeholder NAME
       (``recover_name_token``) and logs "Incomplete member access".
    2. Otherwise, if the parser accepts one of ``_MISSING_TOKENS``, the first
       such token is inserted and "Missing <TOKEN>" is logged.
    3. Otherwise "Unexpected token '<value>'" is logged and the token skipped.

    In cases 1 and 2 the offending token is then re-fed, silently inserting
    further missing tokens until it is accepted; if none fits, False is
    returned.

    Returns:
        True to resume parsing, False to give up on this error (lark then
        raises it).
    """
    ip = e.interactive_parser

    def insert_missing(parser: jl.InteractiveParser) -> Tok | None:
        """Feed the first acceptable token from ``_MISSING_TOKENS``, if any."""
        acceptable = parser.accepts()
        candidate = next(
            (t for t in JacParser._MISSING_TOKENS if t.name in acceptable), None
        )
        if candidate is not None:
            parser.feed_token(jl.Token(candidate.name, TOKEN_MAP[candidate.name]))
        return candidate

    def refeed(parser: jl.InteractiveParser, token: jl.Token) -> bool:
        """Feed ``token``, inserting missing tokens first until it is accepted."""
        while token.type not in parser.accepts():
            if not insert_missing(parser):
                return False
        parser.feed_token(token)
        return True

    if not isinstance(e, jl.UnexpectedToken):
        return False

    history = e.token_history
    previous = history[-1] if history else None
    if previous and previous.type == Tok.DOT.name and Tok.NAME.name in e.accepts:
        # Dangling member access such as ``obj.``: invent a member name.
        self.log_error("Incomplete member access", self.error_to_token(e))
        ip.feed_token(jl.Token(Tok.NAME.name, "recover_name_token"))
        return refeed(ip, e.token)

    inserted = insert_missing(ip)
    if inserted:
        # Only this first insertion is reported; any extra tokens that
        # refeed() has to insert afterwards are not logged.
        self.log_error(f"Missing {inserted.name}", self.error_to_token(e))
        return refeed(ip, e.token)

    # No repair possible: report the token and skip it so parsing can resume.
    self.log_error(
        f"Unexpected token '{e.token.value}'", self.error_to_token(e)
    )
    return True
|
|
86
146
|
|
|
147
|
+
def error_to_message(self, e: jl.UnexpectedInput) -> str:
    """Build a short, human-readable message for a lark parse error.

    ``UnexpectedToken`` errors name the offending token's text; every other
    kind of ``UnexpectedInput`` gets the generic "Syntax Error".
    """
    # TODO: Match more specific errors with lark's example based matching
    # (e.match_examples()).
    # Reference: https://github.com/lark-parser/lark/blob/master/examples/advanced/error_reporting_lalr.py
    if not isinstance(e, jl.UnexpectedToken):
        return "Syntax Error"
    return f"Unexpected token '{e.token.value}'"
|
|
155
|
+
|
|
156
|
+
def error_to_token(self, e: jl.UnexpectedInput) -> uni.Token:
    """Build a placeholder token spanning the source location of ``e``.

    The result is a ``uni.EmptyToken`` bound to ``self.ir_in``, suitable as a
    ``node_override`` for ``log_error``. For an ``UnexpectedToken`` with a
    non-empty token the span covers the whole offending token, including its
    last line when the token spans several lines. Otherwise the span is the
    single character at the error position.
    """
    catch_error = uni.EmptyToken()
    catch_error.orig_src = self.ir_in
    catch_error.line_no = e.line
    catch_error.end_line = e.line
    catch_error.c_start = e.column
    catch_error.pos_start = e.pos_in_stream or 0
    # Default: a one-character span at the error position.
    catch_error.c_end = e.column + 1
    catch_error.pos_end = catch_error.pos_start + 1
    tok = e.token if isinstance(e, jl.UnexpectedToken) else None
    if tok:
        # end_column is measured on the token's *last* line, so end_line must
        # follow it; otherwise multi-line tokens (e.g. triple-quoted strings)
        # produce a span whose end column belongs to a different line.
        catch_error.end_line = tok.end_line or e.line
        catch_error.c_end = tok.end_column or (e.column + 1)
        catch_error.pos_end = tok.end_pos or (catch_error.pos_start + 1)
    return catch_error
|
|
171
|
+
|
|
172
|
+
def report_errors(self, *, colors: bool = True) -> None:
    """Print every error in ``self.errors_had`` to stderr.

    Each error is written as an "Error:" label (red when colors are on),
    followed by the alert's ``pretty_print`` output. Colors are always
    disabled when stderr is not a terminal.
    """
    # TODO: Write a better IO system.
    # NOTE: Errors go to stderr because the LSP's JSON-RPC channel uses stdout.
    # FIXME: When stderr is a pipe/redirect/file, colors are forced off; there
    # should be a configuration option to force them back on.
    colors = colors if sys.stderr.isatty() else False
    label = f"{ANSIColors.RED}Error:{ANSIColors.END}" if colors else "Error:"
    for alert in self.errors_had:
        print(label, end=" ", file=sys.stderr)
        print(alert.pretty_print(colors=colors), file=sys.stderr)
|
|
187
|
+
|
|
87
188
|
@staticmethod
|
|
88
189
|
def _comment_callback(comment: jl.Token) -> None:
|
|
89
190
|
JacParser.comment_cache.append(comment)
|
|
@@ -848,25 +949,104 @@ class JacParser(Transform[uni.Source, uni.Module]):
|
|
|
848
949
|
if self.match_token(Tok.RETURN_HINT):
|
|
849
950
|
return_spec = self.consume(uni.Expr)
|
|
850
951
|
return uni.FuncSignature(
|
|
952
|
+
posonly_params=[],
|
|
851
953
|
params=[],
|
|
954
|
+
varargs=None,
|
|
955
|
+
kwonlyargs=[],
|
|
956
|
+
kwargs=None,
|
|
852
957
|
return_type=return_spec,
|
|
853
958
|
kid=self.flat_cur_nodes,
|
|
854
959
|
)
|
|
855
960
|
# Otherwise, parse the traditional parameter list form
|
|
856
961
|
else:
|
|
857
962
|
self.consume_token(Tok.LPAREN)
|
|
858
|
-
|
|
963
|
+
all_params = self.match(list) or []
|
|
964
|
+
posonly_params, params, varargs, kwonlyargs, kwargs = (
|
|
965
|
+
self._parse_parameter_categories(all_params)
|
|
966
|
+
)
|
|
859
967
|
self.consume_token(Tok.RPAREN)
|
|
860
968
|
if self.match_token(Tok.RETURN_HINT):
|
|
861
969
|
return_spec = self.consume(uni.Expr)
|
|
862
970
|
return uni.FuncSignature(
|
|
863
|
-
|
|
864
|
-
|
|
865
|
-
|
|
971
|
+
posonly_params=posonly_params,
|
|
972
|
+
params=params,
|
|
973
|
+
varargs=varargs,
|
|
974
|
+
kwonlyargs=kwonlyargs,
|
|
975
|
+
kwargs=kwargs,
|
|
866
976
|
return_type=return_spec,
|
|
867
977
|
kid=self.flat_cur_nodes,
|
|
868
978
|
)
|
|
869
979
|
|
|
980
|
+
def _parse_parameter_categories(self, all_params: list[uni.UniNode]) -> tuple[
    list[uni.ParamVar],
    list[uni.ParamVar],
    uni.ParamVar | None,
    list[uni.ParamVar],
    uni.ParamVar | None,
]:
    """Split a flat parameter list into Python's parameter categories.

    ``all_params`` holds ``ParamVar`` nodes interleaved with ``COMMA``, ``/``
    (``DIV``) and bare ``*`` (``STAR_MUL``) tokens. Each ``ParamVar`` gets its
    ``param_kind`` set and is placed in exactly one bucket. Ordering mistakes
    are reported through ``log_error`` (see ``_update_parameter_state``); they
    are not raised.

    Returns:
        ``(posonly_params, params, varargs, kwonlyargs, kwargs)``.
    """
    posonly_params: list[uni.ParamVar] = []
    params: list[uni.ParamVar] = []
    varargs: uni.ParamVar | None = None
    kwonlyargs: list[uni.ParamVar] = []
    kwargs: uni.ParamVar | None = None

    # Everything before a '/' is positional-only, so start in that state
    # whenever a '/' appears anywhere in the list.
    has_slash = any(
        isinstance(p, uni.Token) and p.name == Tok.DIV for p in all_params
    )
    cur_state = "posonly" if has_slash else "positional"

    for cur_nd in all_params:
        cur_state = self._update_parameter_state(cur_nd, cur_state)
        if not isinstance(cur_nd, uni.ParamVar):
            continue
        if cur_state == "positional":
            cur_nd.param_kind = uni.ParamKind.NORMAL
            params.append(cur_nd)
        elif cur_state == "posonly":
            cur_nd.param_kind = uni.ParamKind.POSONLY
            posonly_params.append(cur_nd)
        elif cur_state == "varargs":
            cur_nd.param_kind = uni.ParamKind.VARARG
            varargs = cur_nd
            # Parameters after *args are keyword-only.
            cur_state = "keyword_only"
        elif cur_state == "keyword_only":
            cur_nd.param_kind = uni.ParamKind.KWONLY
            kwonlyargs.append(cur_nd)
        elif cur_state == "kwargs":
            cur_nd.param_kind = uni.ParamKind.KWARG
            kwargs = cur_nd
        else:
            raise self.ice()

    return posonly_params, params, varargs, kwonlyargs, kwargs

def _log_param_error(self, detail: str, node: uni.UniNode) -> None:
    """Report a parameter-ordering error located at ``node``."""
    self.parse_ref.log_error(
        f"Invalid syntax in function parameters: {detail}",
        node_override=node,
    )

def _update_parameter_state(self, cur_nd: uni.UniNode, cur_state: str) -> str:
    """Return the parameter state that applies after seeing ``cur_nd``.

    States: ``posonly`` (before '/'), ``positional``, ``varargs`` (the
    ``*args`` parameter itself), ``keyword_only`` (after '*' or ``*args``) and
    ``kwargs`` (``**kwargs``). Invalid orderings are logged. A state is still
    returned so categorisation can continue and report any further problems.
    """
    if isinstance(cur_nd, uni.Token):
        if cur_nd.name == Tok.DIV:
            if cur_state == "positional":
                # The initial state is 'posonly' whenever a '/' exists, so
                # 'positional' here means an earlier '/' was already seen.
                self._log_param_error("'/' may appear only once.", cur_nd)
            elif cur_state in ("keyword_only", "kwargs"):
                self._log_param_error(
                    "'/' cannot appear after '*' or '**'.", cur_nd
                )
            return "positional"
        if cur_nd.name == Tok.STAR_MUL:
            if cur_state == "keyword_only":
                # A bare '*' or '*args' was already seen.
                self._log_param_error("'*' may appear only once.", cur_nd)
            elif cur_state == "kwargs":
                self._log_param_error("'*' cannot appear after '**'.", cur_nd)
            return "keyword_only"
        # COMMA (or any other separator) leaves the state unchanged.
        return cur_state

    if isinstance(cur_nd, uni.ParamVar):
        if cur_state == "kwargs":
            # Nothing may follow **kwargs; previously this silently
            # overwrote the recorded **kwargs parameter.
            self._log_param_error(
                "parameters cannot appear after '**'.", cur_nd
            )
        if cur_nd.is_vararg:
            if cur_state == "keyword_only":
                self._log_param_error("'*' may appear only once.", cur_nd)
            return "varargs"
        if cur_nd.is_kwargs:
            return "kwargs"
    return cur_state
|
|
1049
|
+
|
|
870
1050
|
def func_decl_params(self, _: None) -> list[uni.UniNode]:
|
|
871
1051
|
"""Grammar rule.
|
|
872
1052
|
|
|
@@ -874,11 +1054,17 @@ class JacParser(Transform[uni.Source, uni.Module]):
|
|
|
874
1054
|
"""
|
|
875
1055
|
return self.cur_nodes
|
|
876
1056
|
|
|
877
|
-
def param_var(self, _: None) -> uni.ParamVar:
|
|
1057
|
+
def param_var(self, _: None) -> uni.ParamVar | uni.Token:
|
|
878
1058
|
"""Grammar rule.
|
|
879
1059
|
|
|
880
|
-
param_var: (STAR_POW | STAR_MUL)?
|
|
1060
|
+
param_var: (STAR_POW | STAR_MUL)? named_ref type_tag (EQ expression)?
|
|
1061
|
+
| DIV
|
|
1062
|
+
| STAR_MUL
|
|
881
1063
|
"""
|
|
1064
|
+
if len(self.cur_nodes) == 1 and (
|
|
1065
|
+
star_only := self.match_token(Tok.DIV) or self.match_token(Tok.STAR_MUL)
|
|
1066
|
+
):
|
|
1067
|
+
return star_only
|
|
882
1068
|
star = self.match_token(Tok.STAR_POW) or self.match_token(Tok.STAR_MUL)
|
|
883
1069
|
name = self.consume(uni.Name)
|
|
884
1070
|
type_tag = self.consume(uni.SubTag)
|
|
@@ -1028,7 +1214,6 @@ class JacParser(Transform[uni.Source, uni.Module]):
|
|
|
1028
1214
|
| (yield_expr | KW_YIELD) SEMI
|
|
1029
1215
|
| raise_stmt SEMI
|
|
1030
1216
|
| assert_stmt SEMI
|
|
1031
|
-
| check_stmt SEMI
|
|
1032
1217
|
| assignment SEMI
|
|
1033
1218
|
| delete_stmt SEMI
|
|
1034
1219
|
| report_stmt SEMI
|
|
@@ -1307,18 +1492,6 @@ class JacParser(Transform[uni.Source, uni.Module]):
|
|
|
1307
1492
|
kid=self.cur_nodes,
|
|
1308
1493
|
)
|
|
1309
1494
|
|
|
1310
|
-
def check_stmt(self, _: None) -> uni.CheckStmt:
|
|
1311
|
-
"""Grammar rule.
|
|
1312
|
-
|
|
1313
|
-
check_stmt: KW_CHECK expression
|
|
1314
|
-
"""
|
|
1315
|
-
self.consume_token(Tok.KW_CHECK)
|
|
1316
|
-
target = self.consume(uni.Expr)
|
|
1317
|
-
return uni.CheckStmt(
|
|
1318
|
-
target=target,
|
|
1319
|
-
kid=self.cur_nodes,
|
|
1320
|
-
)
|
|
1321
|
-
|
|
1322
1495
|
def ctrl_stmt(self, _: None) -> uni.CtrlStmt | uni.DisengageStmt:
|
|
1323
1496
|
"""Grammar rule.
|
|
1324
1497
|
|
|
@@ -1545,9 +1718,13 @@ class JacParser(Transform[uni.Source, uni.Module]):
|
|
|
1545
1718
|
sig_kid.append(return_type)
|
|
1546
1719
|
signature = (
|
|
1547
1720
|
uni.FuncSignature(
|
|
1721
|
+
posonly_params=[],
|
|
1548
1722
|
params=(
|
|
1549
1723
|
self.extract_from_list(params, uni.ParamVar) if params else []
|
|
1550
1724
|
),
|
|
1725
|
+
varargs=None,
|
|
1726
|
+
kwonlyargs=[],
|
|
1727
|
+
kwargs=None,
|
|
1551
1728
|
return_type=return_type,
|
|
1552
1729
|
kid=sig_kid,
|
|
1553
1730
|
)
|
|
@@ -2328,6 +2505,7 @@ class JacParser(Transform[uni.Source, uni.Module]):
|
|
|
2328
2505
|
LSQUARE (KW_NODE| KW_EDGE)? expression? (edge_op_ref (filter_compr | expression)?)+ RSQUARE
|
|
2329
2506
|
"""
|
|
2330
2507
|
self.consume_token(Tok.LSQUARE)
|
|
2508
|
+
is_async = bool(self.match_token(Tok.KW_ASYNC))
|
|
2331
2509
|
edges_only = bool(self.match_token(Tok.KW_EDGE))
|
|
2332
2510
|
self.match_token(Tok.KW_NODE)
|
|
2333
2511
|
valid_chain = []
|
|
@@ -2338,6 +2516,7 @@ class JacParser(Transform[uni.Source, uni.Module]):
|
|
|
2338
2516
|
return uni.EdgeRefTrailer(
|
|
2339
2517
|
chain=valid_chain,
|
|
2340
2518
|
edges_only=edges_only,
|
|
2519
|
+
is_async=is_async,
|
|
2341
2520
|
kid=self.cur_nodes,
|
|
2342
2521
|
)
|
|
2343
2522
|
|
|
@@ -4,7 +4,6 @@ from ..transform import Alert, Transform # noqa: I100
|
|
|
4
4
|
from .annex_pass import JacAnnexPass # noqa: I100
|
|
5
5
|
from .binder_pass import BinderPass # noqa: I100
|
|
6
6
|
from .sym_tab_build_pass import SymTabBuildPass, UniPass # noqa: I100
|
|
7
|
-
from .sym_tab_link_pass import SymTabLinkPass # noqa: I100
|
|
8
7
|
from .def_use_pass import DefUsePass # noqa: I100
|
|
9
8
|
from .sem_def_match_pass import SemDefMatchPass # noqa: I100
|
|
10
9
|
from .import_pass import JacImportDepsPass # noqa: I100
|
|
@@ -12,10 +11,10 @@ from .def_impl_match_pass import DeclImplMatchPass # noqa: I100
|
|
|
12
11
|
from .type_checker_pass import TypeCheckPass # noqa: I100
|
|
13
12
|
from .pyast_load_pass import PyastBuildPass # type: ignore # noqa: I100
|
|
14
13
|
from .pyast_gen_pass import PyastGenPass # noqa: I100
|
|
14
|
+
from .predynamo_pass import PreDynamoPass # noqa: I100
|
|
15
15
|
from .pybc_gen_pass import PyBytecodeGenPass # noqa: I100
|
|
16
16
|
from .cfg_build_pass import CFGBuildPass # noqa: I100
|
|
17
17
|
from .pyjac_ast_link_pass import PyJacAstLinkPass # noqa: I100
|
|
18
|
-
from .inheritance_pass import InheritancePass # noqa: I100
|
|
19
18
|
|
|
20
19
|
|
|
21
20
|
__all__ = [
|
|
@@ -28,14 +27,13 @@ __all__ = [
|
|
|
28
27
|
"BinderPass",
|
|
29
28
|
"TypeCheckPass",
|
|
30
29
|
"SymTabBuildPass",
|
|
31
|
-
"SymTabLinkPass",
|
|
32
30
|
"DeclImplMatchPass",
|
|
33
31
|
"DefUsePass",
|
|
34
32
|
"SemDefMatchPass",
|
|
35
33
|
"PyastBuildPass",
|
|
36
34
|
"PyastGenPass",
|
|
35
|
+
"PreDynamoPass",
|
|
37
36
|
"PyBytecodeGenPass",
|
|
38
37
|
"CFGBuildPass",
|
|
39
38
|
"PyJacAstLinkPass",
|
|
40
|
-
"InheritancePass",
|
|
41
39
|
]
|
|
@@ -30,7 +30,6 @@ class DefUsePass(UniPass):
|
|
|
30
30
|
"""Jac Ast build pass."""
|
|
31
31
|
|
|
32
32
|
def enter_archetype(self, node: uni.Archetype) -> None:
|
|
33
|
-
node.sym_tab.inherit_baseclasses_sym(node)
|
|
34
33
|
|
|
35
34
|
def inform_from_walker(node: uni.UniNode) -> None:
|
|
36
35
|
for i in (
|
|
@@ -47,9 +46,6 @@ class DefUsePass(UniPass):
|
|
|
47
46
|
if isinstance(i.body, uni.ImplDef):
|
|
48
47
|
inform_from_walker(i.body)
|
|
49
48
|
|
|
50
|
-
def enter_enum(self, node: uni.Enum) -> None:
|
|
51
|
-
node.sym_tab.inherit_baseclasses_sym(node)
|
|
52
|
-
|
|
53
49
|
def enter_type_ref(self, node: uni.TypeRef) -> None:
|
|
54
50
|
node.sym_tab.use_lookup(node)
|
|
55
51
|
|
|
@@ -0,0 +1,221 @@
|
|
|
1
|
+
"""Pytorch Fix Pass."""
|
|
2
|
+
|
|
3
|
+
import ast as ast3
|
|
4
|
+
from typing import Optional, TypeVar, cast
|
|
5
|
+
|
|
6
|
+
import jaclang.compiler.unitree as uni
|
|
7
|
+
from jaclang.compiler.constant import Tokens as Tok
|
|
8
|
+
from jaclang.compiler.passes import UniPass
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
T = TypeVar("T", bound=ast3.AST)


class PreDynamoPass(UniPass):
    """Pre-Dynamo pass for PyTorch.

    Rewrites an ``if``/``else`` whose two branches are single, structurally
    parallel statements into one branch-free statement built on
    ``torch.where(condition, then_value, else_value)``:

    * same-target assignments  -> ``x = torch.where(c, a, b)``
    * returns                  -> ``return torch.where(c, a, b)``
    * ``obj.fn("name", tensor, **kw)`` calls with the same callee, name and
      keyword arguments        -> ``__name = torch.where(c, a, b)`` followed
      by ``obj.fn("name", __name, **kw)``

    NOTE(review): presumably this removes data-dependent Python control flow
    ahead of TorchDynamo tracing. ``torch.where`` evaluates *both* branch
    expressions, so the rewrite assumes they are side-effect free and that the
    condition is a boolean tensor. Confirm this with the pass's users.
    """

    def enter_node(self, node: uni.UniNode) -> None:
        """Enter node."""
        super().enter_node(node)

    def exit_node(self, node: uni.UniNode) -> None:
        """Exit node."""
        super().exit_node(node)

    def gen_name(self, node: uni.UniNode, name: Tok, value: str) -> uni.Name:
        """Create a synthetic ``uni.Name`` located at ``node``.

        Only the source, first/last line and start column are copied from
        ``node``. The end column and stream positions are left at 0.
        """
        return uni.Name(
            name=name,
            value=value,
            orig_src=node.loc.orig_src,
            col_start=node.loc.col_start,
            col_end=0,
            line=node.loc.first_line,
            end_line=node.loc.last_line,
            pos_start=0,
            pos_end=0,
        )

    def replace_node(
        self,
        new_nodes: list[uni.UniNode] | uni.UniNode,
        old_node: uni.UniNode,
        attr: str,
    ) -> None:
        """Replace ``old_node`` with ``new_nodes`` in its parent.

        ``old_node`` is swapped out in the parent's ``attr`` list (e.g.
        ``body``) and in ``parent.kid``, and the new node(s) are re-parented.
        A single node is substituted in place. A list of nodes is spliced in
        by rebinding both lists to new list objects.
        NOTE(review): rebinding presumably avoids disturbing a traversal that
        is still iterating the old ``kid`` list. Confirm before changing it.
        """
        parent = old_node.parent
        if isinstance(new_nodes, uni.UniNode):
            new_nodes.parent = parent
            if hasattr(parent, attr):
                lst = getattr(parent, attr)
                if old_node in lst:
                    idx = lst.index(old_node)
                    lst[idx] = new_nodes
            if hasattr(parent, "kid") and old_node in parent.kid:
                idx = parent.kid.index(old_node)
                parent.kid[idx] = new_nodes
        else:  # list of nodes
            for n in new_nodes:
                n.parent = parent
            if hasattr(parent, attr):
                lst = getattr(parent, attr)
                if old_node in lst:
                    idx = lst.index(old_node)
                    setattr(parent, attr, lst[:idx] + new_nodes + lst[idx + 1 :])
            if hasattr(parent, "kid") and old_node in parent.kid:
                idx = parent.kid.index(old_node)
                parent.kid = parent.kid[:idx] + new_nodes + parent.kid[idx + 1 :]

    def check_same_lhs(
        self, assign_a: uni.UniNode, assign_b: uni.UniNode
    ) -> Optional[uni.Name]:
        """Return the shared target of two plain ``name = value`` assignments.

        Returns None unless both nodes are assignments with exactly one target,
        that target is a ``uni.Name`` with the same text in both, and neither
        is augmented. Augmented (``x += v``) and multi-target (``x = y = v``)
        forms are rejected because rewriting them as ``x = torch.where(...)``
        would change their meaning or drop targets.
        """
        if not (
            isinstance(assign_a, uni.Assignment)
            and isinstance(assign_b, uni.Assignment)
        ):
            return None
        if len(assign_a.target) != 1 or len(assign_b.target) != 1:
            return None
        if (
            getattr(assign_a, "aug_op", None) is not None
            or getattr(assign_b, "aug_op", None) is not None
        ):
            return None
        ta, tb = assign_a.target[0], assign_b.target[0]
        if not (isinstance(ta, uni.Name) and isinstance(tb, uni.Name)):
            return None
        if ta.value != tb.value:
            return None
        return ta  # common target

    def check_call(self, node: uni.ExprStmt) -> Optional[tuple]:
        """Match ``target("name", tensor_expr, key=value, ...)`` statements.

        Returns ``(target, name, tensor_expr, kwargs)``, where ``name`` is the
        single ``uni.String`` literal token and ``kwargs`` maps keyword names
        to value expressions. Returns None when the statement does not match.
        Calls with extra positional arguments or ``**`` expansions are
        rejected because the rewrite would silently drop them. Implicitly
        concatenated or f-string names are rejected because they are not a
        single string token.
        """
        if not (isinstance(node, uni.ExprStmt) and isinstance(node.expr, uni.FuncCall)):
            return None
        call = node.expr
        params = call.params or []
        if not isinstance(call.target, uni.AtomTrailer) or len(params) < 2:
            return None
        name, tensor_expr = params[0], params[1]
        if isinstance(name, uni.MultiString):
            if len(name.strings) != 1:
                return None
            name = name.strings[0]
        if not isinstance(name, uni.String) or not isinstance(tensor_expr, uni.Expr):
            return None
        kwargs = {}
        for extra in params[2:]:
            if not isinstance(extra, uni.KWPair) or extra.key is None:
                return None
            kwargs[extra.key._sym_name] = extra.value
        return (call.target, name, tensor_expr, kwargs)

    @staticmethod
    def _string_value(tok: uni.String) -> Optional[str]:
        """Decode a string-literal token's source text (e.g. ``"buf"``).

        Uses ``ast.literal_eval`` rather than ``eval`` so source text is never
        executed. Returns None when the text is not a plain ``str`` literal.
        """
        try:
            val = ast3.literal_eval(tok.value)
        except (ValueError, SyntaxError, TypeError):
            return None
        return val if isinstance(val, str) else None

    def _same_expr(
        self, a: Optional[uni.UniNode], b: Optional[uni.UniNode]
    ) -> bool:
        """Conservatively decide whether two expressions are syntactically equal.

        Names, literal tokens and attribute/index trailers are compared.
        Anything else counts as different, so the caller skips the rewrite.
        """
        if a is None or b is None:
            return a is b
        if isinstance(a, uni.NameAtom) and isinstance(b, uni.NameAtom):
            return type(a) is type(b) and a._sym_name == b._sym_name
        if isinstance(a, uni.Token) and isinstance(b, uni.Token):
            return a.name == b.name and a.value == b.value
        if isinstance(a, uni.AtomTrailer) and isinstance(b, uni.AtomTrailer):
            return (
                a.is_attr == b.is_attr
                and a.is_null_ok == b.is_null_ok
                and self._same_expr(a.target, b.target)
                and self._same_expr(a.right, b.right)
            )
        return False

    def _where_call(
        self, node: uni.IfStmt, then_expr: uni.Expr, else_expr: uni.Expr
    ) -> uni.FuncCall:
        """Build ``torch.where(node.condition, then_expr, else_expr)``."""
        torch_name = self.gen_name(node, Tok.NAME, "torch")
        where_name = self.gen_name(node, Tok.NAME, "where")
        target = uni.AtomTrailer(
            target=torch_name,
            right=where_name,
            is_attr=True,
            is_null_ok=False,
            kid=[torch_name, where_name],
        )
        # kid mirrors params; it must hold the value expressions, not the
        # original branch statements.
        return uni.FuncCall(
            target=target,
            params=[node.condition, then_expr, else_expr],
            genai_call=None,
            kid=[target, node.condition, then_expr, else_expr],
        )

    def exit_if_stmt(self, node: uni.IfStmt) -> None:
        """Rewrite ``if c {A} else {B}`` into a ``torch.where`` form if possible.

        ``node`` is left untouched unless it has a plain ``else`` (not an
        ``elif``) and each branch is exactly one statement of a supported,
        matching kind.
        """
        else_body = node.else_body
        # An elif has its own condition, which one torch.where cannot express.
        if not isinstance(else_body, uni.ElseStmt):
            return
        # With more than one statement per branch, a single-statement
        # rewrite would silently drop the rest of the branch.
        if len(node.body) != 1 or len(else_body.body) != 1:
            return
        a0, b0 = node.body[0], else_body.body[0]
        if isinstance(a0, uni.Assignment) and isinstance(b0, uni.Assignment):
            self._rewrite_assignments(node, a0, b0)
        elif isinstance(a0, uni.ReturnStmt) and isinstance(b0, uni.ReturnStmt):
            self._rewrite_returns(node, a0, b0)
        elif isinstance(a0, uni.ExprStmt) and isinstance(b0, uni.ExprStmt):
            self._rewrite_calls(node, a0, b0)

    def _rewrite_assignments(
        self, node: uni.IfStmt, a0: uni.Assignment, b0: uni.Assignment
    ) -> None:
        """``x = a`` / ``x = b`` branches -> ``x = torch.where(c, a, b)``."""
        lhs = self.check_same_lhs(a0, b0)
        # A value-less assignment (bare declaration) has nothing to select.
        if lhs is None or a0.value is None or b0.value is None:
            return
        call = self._where_call(
            node, cast(uni.Expr, a0.value), cast(uni.Expr, b0.value)
        )
        new_node = uni.Assignment(
            target=[lhs], value=call, type_tag=None, kid=[lhs, call]
        )
        self.replace_node(new_node, node, "body")

    def _rewrite_returns(
        self, node: uni.IfStmt, a0: uni.ReturnStmt, b0: uni.ReturnStmt
    ) -> None:
        """``return a`` / ``return b`` -> ``return torch.where(c, a, b)``."""
        aexpr, bexpr = a0.expr, b0.expr
        if aexpr is None or bexpr is None:
            return
        call = self._where_call(node, cast(uni.Expr, aexpr), cast(uni.Expr, bexpr))
        new_node = uni.ReturnStmt(expr=call, kid=[call])
        self.replace_node(new_node, node, "body")

    def _rewrite_calls(
        self, node: uni.IfStmt, a0: uni.ExprStmt, b0: uni.ExprStmt
    ) -> None:
        """Merge parallel ``obj.fn("name", tensor, **kw)`` calls.

        Produces ``__name = torch.where(c, a, b)`` followed by
        ``obj.fn("name", __name, **kw)``. This is only done when both calls
        have the same callee, the same decoded name, and identical keyword
        arguments.
        """
        a_reg = self.check_call(a0)
        b_reg = self.check_call(b0)
        if a_reg is None or b_reg is None:
            return
        a_target, a_name, a_expr, a_kwargs = a_reg
        b_target, b_name, b_expr, b_kwargs = b_reg

        buf_name = self._string_value(a_name)
        if buf_name is None or buf_name != self._string_value(b_name):
            return
        # Merging calls to different callees, or calls whose kwargs differ,
        # would silently replace branch B's call with branch A's.
        if not self._same_expr(a_target, b_target):
            return
        if a_kwargs.keys() != b_kwargs.keys():
            return
        if not all(self._same_expr(v, b_kwargs[k]) for k, v in a_kwargs.items()):
            return

        tmp_ident = f"__{buf_name}"
        tmp_name = self.gen_name(node, Tok.NAME, tmp_ident)
        tmp_name.py_ctx_func = ast3.Store
        call = self._where_call(node, a_expr, b_expr)
        assign_node = uni.Assignment(
            target=[tmp_name],
            value=call,
            type_tag=None,
            kid=[tmp_name, call],
        )

        kwargs_nodes = []
        for key, val in a_kwargs.items():
            key_name = self.gen_name(node, Tok.NAME, key)
            kwargs_nodes.append(uni.KWPair(key_name, val, [key_name, val]))
        param_name = self.gen_name(node, Tok.NAME, tmp_ident)
        reg_call = uni.FuncCall(
            target=a_target,
            params=[a_name, param_name] + kwargs_nodes,
            genai_call=None,
            kid=[a_target, a_name, param_name] + kwargs_nodes,
        )
        reg_node = uni.ExprStmt(expr=reg_call, in_fstring=False, kid=[reg_call])
        self.replace_node([assign_node, reg_node], node, "body")
|