jaclang 0.8.6__py3-none-any.whl → 0.8.8__py3-none-any.whl

This diff shows the contents of two publicly available versions of this package as released to a supported registry. It is provided for informational purposes only and reflects the changes between those versions exactly as they appear in the public registry.

Potentially problematic release.


This version of jaclang may be problematic; see the package's registry page for details.

Files changed (103)
  1. jaclang/cli/cli.md +3 -3
  2. jaclang/cli/cli.py +37 -37
  3. jaclang/cli/cmdreg.py +45 -140
  4. jaclang/compiler/constant.py +0 -1
  5. jaclang/compiler/jac.lark +3 -6
  6. jaclang/compiler/larkparse/jac_parser.py +2 -2
  7. jaclang/compiler/parser.py +213 -34
  8. jaclang/compiler/passes/main/__init__.py +2 -4
  9. jaclang/compiler/passes/main/def_use_pass.py +0 -4
  10. jaclang/compiler/passes/main/predynamo_pass.py +221 -0
  11. jaclang/compiler/passes/main/pyast_gen_pass.py +83 -55
  12. jaclang/compiler/passes/main/pyast_load_pass.py +66 -40
  13. jaclang/compiler/passes/main/sym_tab_build_pass.py +1 -1
  14. jaclang/compiler/passes/main/tests/fixtures/checker/import_sym.jac +2 -0
  15. jaclang/compiler/passes/main/tests/fixtures/checker/import_sym_test.jac +6 -0
  16. jaclang/compiler/passes/main/tests/fixtures/checker/imported_sym.jac +5 -0
  17. jaclang/compiler/passes/main/tests/fixtures/checker_arg_param_match.jac +37 -0
  18. jaclang/compiler/passes/main/tests/fixtures/checker_arity.jac +18 -0
  19. jaclang/compiler/passes/main/tests/fixtures/checker_binary_op.jac +21 -0
  20. jaclang/compiler/passes/main/tests/fixtures/checker_call_expr_class.jac +12 -0
  21. jaclang/compiler/passes/main/tests/fixtures/checker_cat_is_animal.jac +18 -0
  22. jaclang/compiler/passes/main/tests/fixtures/checker_cyclic_symbol.jac +4 -0
  23. jaclang/compiler/passes/main/tests/fixtures/checker_expr_call.jac +9 -0
  24. jaclang/compiler/passes/main/tests/fixtures/checker_float.jac +7 -0
  25. jaclang/compiler/passes/main/tests/fixtures/checker_import_missing_module.jac +13 -0
  26. jaclang/compiler/passes/main/tests/fixtures/checker_magic_call.jac +17 -0
  27. jaclang/compiler/passes/main/tests/fixtures/checker_mod_path.jac +8 -0
  28. jaclang/compiler/passes/main/tests/fixtures/checker_param_types.jac +11 -0
  29. jaclang/compiler/passes/main/tests/fixtures/checker_self_type.jac +9 -0
  30. jaclang/compiler/passes/main/tests/fixtures/checker_sym_inherit.jac +42 -0
  31. jaclang/compiler/passes/main/tests/fixtures/predynamo_fix3.jac +43 -0
  32. jaclang/compiler/passes/main/tests/fixtures/predynamo_where_assign.jac +13 -0
  33. jaclang/compiler/passes/main/tests/fixtures/predynamo_where_return.jac +11 -0
  34. jaclang/compiler/passes/main/tests/test_checker_pass.py +265 -0
  35. jaclang/compiler/passes/main/tests/test_predynamo_pass.py +57 -0
  36. jaclang/compiler/passes/main/type_checker_pass.py +36 -61
  37. jaclang/compiler/passes/tool/doc_ir_gen_pass.py +204 -44
  38. jaclang/compiler/passes/tool/jac_formatter_pass.py +119 -69
  39. jaclang/compiler/passes/tool/tests/fixtures/corelib_fmt.jac +3 -3
  40. jaclang/compiler/passes/tool/tests/fixtures/general_format_checks/triple_quoted_string.jac +4 -5
  41. jaclang/compiler/passes/tool/tests/fixtures/tagbreak.jac +171 -11
  42. jaclang/compiler/passes/transform.py +12 -8
  43. jaclang/compiler/program.py +14 -6
  44. jaclang/compiler/tests/fixtures/jac_import_py_files.py +4 -0
  45. jaclang/compiler/tests/fixtures/jac_module.jac +3 -0
  46. jaclang/compiler/tests/fixtures/multiple_syntax_errors.jac +10 -0
  47. jaclang/compiler/tests/fixtures/python_module.py +1 -0
  48. jaclang/compiler/tests/test_importer.py +39 -0
  49. jaclang/compiler/tests/test_parser.py +49 -0
  50. jaclang/compiler/type_system/operations.py +104 -0
  51. jaclang/compiler/type_system/type_evaluator.py +470 -47
  52. jaclang/compiler/type_system/type_utils.py +246 -0
  53. jaclang/compiler/type_system/types.py +58 -2
  54. jaclang/compiler/unitree.py +79 -94
  55. jaclang/langserve/engine.jac +253 -230
  56. jaclang/langserve/server.jac +46 -15
  57. jaclang/langserve/tests/fixtures/circle.jac +3 -3
  58. jaclang/langserve/tests/fixtures/circle_err.jac +3 -3
  59. jaclang/langserve/tests/fixtures/circle_pure.test.jac +3 -3
  60. jaclang/langserve/tests/fixtures/completion_test_err.jac +10 -0
  61. jaclang/langserve/tests/server_test/circle_template.jac +80 -0
  62. jaclang/langserve/tests/server_test/glob_template.jac +4 -0
  63. jaclang/langserve/tests/server_test/test_lang_serve.py +154 -312
  64. jaclang/langserve/tests/server_test/utils.py +153 -116
  65. jaclang/langserve/tests/test_dev_server.py +1 -1
  66. jaclang/langserve/tests/test_server.py +30 -86
  67. jaclang/langserve/utils.jac +56 -63
  68. jaclang/runtimelib/machine.py +7 -0
  69. jaclang/runtimelib/meta_importer.py +27 -1
  70. jaclang/runtimelib/tests/fixtures/custom_access_validation.jac +1 -1
  71. jaclang/runtimelib/tests/fixtures/savable_object.jac +2 -2
  72. jaclang/settings.py +18 -14
  73. jaclang/tests/fixtures/abc_check.jac +3 -3
  74. jaclang/tests/fixtures/arch_rel_import_creation.jac +12 -12
  75. jaclang/tests/fixtures/chandra_bugs2.jac +3 -3
  76. jaclang/tests/fixtures/create_dynamic_archetype.jac +13 -13
  77. jaclang/tests/fixtures/jac_run_py_bugs.py +18 -0
  78. jaclang/tests/fixtures/jac_run_py_import.py +13 -0
  79. jaclang/tests/fixtures/lambda_arg_annotation.jac +15 -0
  80. jaclang/tests/fixtures/lambda_self.jac +18 -0
  81. jaclang/tests/fixtures/maxfail_run_test.jac +4 -4
  82. jaclang/tests/fixtures/params/param_syntax_err.jac +9 -0
  83. jaclang/tests/fixtures/params/test_complex_params.jac +42 -0
  84. jaclang/tests/fixtures/params/test_failing_kwonly.jac +207 -0
  85. jaclang/tests/fixtures/params/test_failing_posonly.jac +116 -0
  86. jaclang/tests/fixtures/params/test_failing_varargs.jac +300 -0
  87. jaclang/tests/fixtures/params/test_kwonly_params.jac +29 -0
  88. jaclang/tests/fixtures/py2jac_params.py +8 -0
  89. jaclang/tests/fixtures/run_test.jac +4 -4
  90. jaclang/tests/test_cli.py +103 -18
  91. jaclang/tests/test_language.py +74 -16
  92. jaclang/utils/helpers.py +47 -2
  93. jaclang/utils/module_resolver.py +11 -1
  94. jaclang/utils/test.py +8 -0
  95. jaclang/utils/treeprinter.py +0 -18
  96. {jaclang-0.8.6.dist-info → jaclang-0.8.8.dist-info}/METADATA +3 -3
  97. {jaclang-0.8.6.dist-info → jaclang-0.8.8.dist-info}/RECORD +99 -62
  98. {jaclang-0.8.6.dist-info → jaclang-0.8.8.dist-info}/WHEEL +1 -1
  99. jaclang/compiler/passes/main/inheritance_pass.py +0 -131
  100. jaclang/langserve/dev_engine.jac +0 -645
  101. jaclang/langserve/dev_server.jac +0 -201
  102. jaclang/langserve/tests/server_test/code_test.py +0 -0
  103. {jaclang-0.8.6.dist-info → jaclang-0.8.8.dist-info}/entry_points.txt +0 -0
@@ -5,12 +5,14 @@ from __future__ import annotations
5
5
  import keyword
6
6
  import logging
7
7
  import os
8
+ import sys
8
9
  from typing import Callable, Sequence, TYPE_CHECKING, TypeAlias, TypeVar, cast
9
10
 
10
11
  import jaclang.compiler.unitree as uni
11
- from jaclang.compiler import jac_lark as jl
12
+ from jaclang.compiler import TOKEN_MAP, jac_lark as jl
12
13
  from jaclang.compiler.constant import EdgeDir, Tokens as Tok
13
14
  from jaclang.compiler.passes.main import Transform
15
+ from jaclang.utils.helpers import ANSIColors
14
16
  from jaclang.vendor.lark import Lark, Transformer, Tree, logger
15
17
 
16
18
  if TYPE_CHECKING:
@@ -39,22 +41,17 @@ class JacParser(Transform[uni.Source, uni.Module]):
39
41
  tree, comments = JacParser.parse(ir_in.value, on_error=self.error_callback)
40
42
  mod = JacParser.TreeToAST(parser=self).transform(tree)
41
43
  ir_in.comments = [self.proc_comment(i, mod) for i in comments]
42
- if isinstance(mod, uni.Module):
43
- self.ir_out = mod
44
- return mod
45
- else:
44
+ if not isinstance(mod, uni.Module):
46
45
  raise self.ice()
46
+ if len(self.errors_had) != 0:
47
+ mod.has_syntax_errors = True
48
+ self.report_errors()
49
+ self.ir_out = mod
50
+ return mod
47
51
  except jl.UnexpectedInput as e:
48
- catch_error = uni.EmptyToken()
49
- catch_error.orig_src = ir_in
50
- catch_error.line_no = e.line
51
- catch_error.end_line = e.line
52
- catch_error.c_start = e.column
53
- catch_error.c_end = e.column + 1
54
- catch_error.pos_start = e.pos_in_stream or 0
55
- catch_error.pos_end = catch_error.pos_start + 1
52
+ catch_error = self.error_to_token(e)
53
+ error_msg = self.error_to_message(e)
56
54
 
57
- error_msg = "Syntax Error"
58
55
  if len(e.args) >= 1 and isinstance(e.args[0], str):
59
56
  error_msg += e.args[0]
60
57
  self.log_error(error_msg, node_override=catch_error)
@@ -62,7 +59,12 @@ class JacParser(Transform[uni.Source, uni.Module]):
62
59
  except Exception as e:
63
60
  raise e
64
61
 
65
- return uni.Module.make_stub(inject_src=ir_in)
62
+ # If we reach here, there was a syntax error, mark the module as such
63
+ # and report errors.
64
+ self.report_errors()
65
+ mod = uni.Module.make_stub(inject_src=ir_in)
66
+ mod.has_syntax_errors = True
67
+ return mod
66
68
 
67
69
  @staticmethod
68
70
  def proc_comment(token: jl.Token, mod: uni.UniNode) -> uni.CommentToken:
@@ -80,10 +82,109 @@ class JacParser(Transform[uni.Source, uni.Module]):
80
82
  kid=[],
81
83
  )
82
84
 
85
+ _MISSING_TOKENS = [
86
+ Tok.SEMI,
87
+ Tok.COMMA,
88
+ Tok.COLON,
89
+ Tok.RPAREN,
90
+ Tok.RBRACE,
91
+ Tok.RSQUARE,
92
+ Tok.RETURN_HINT,
93
+ ]
94
+
83
95
  def error_callback(self, e: jl.UnexpectedInput) -> bool:
84
96
  """Handle error."""
97
+ iparser = e.interactive_parser
98
+
99
+ def try_feed_missing_token(iparser: jl.InteractiveParser) -> Tok | None:
100
+ """Feed a missing token to the parser."""
101
+ # If any of the below token is missing, insert them and continue parsing.
102
+ accepts = iparser.accepts()
103
+ for tok in JacParser._MISSING_TOKENS:
104
+ if tok.name in accepts:
105
+ iparser.feed_token(jl.Token(tok.name, TOKEN_MAP[tok.name]))
106
+ return tok
107
+ return None
108
+
109
+ def feed_current_token(iparser: jl.InteractiveParser, tok: jl.Token) -> bool:
110
+ """Feed the current token to the parser."""
111
+ while tok.type not in iparser.accepts():
112
+ if not try_feed_missing_token(iparser):
113
+ return False
114
+ iparser.feed_token(tok)
115
+ return True
116
+
117
+ if isinstance(e, jl.UnexpectedToken):
118
+ # If last token is DOT and we expect a NAME, insert a NAME token
119
+ last_tok: jl.Token | None = (
120
+ e.token_history[-1]
121
+ if e.token_history and len(e.token_history) >= 1
122
+ else None
123
+ )
124
+ if (
125
+ last_tok
126
+ and last_tok.type == Tok.DOT.name
127
+ and (Tok.NAME.name in e.accepts)
128
+ ):
129
+ self.log_error("Incomplete member access", self.error_to_token(e))
130
+ iparser.feed_token(jl.Token(Tok.NAME.name, "recover_name_token"))
131
+ return feed_current_token(iparser, e.token)
132
+
133
+ # We're calling try_feed_missing_token twice here because the first missing
134
+ # will be reported as such and we don't for the consequent missing token.
135
+ if tk := try_feed_missing_token(iparser):
136
+ self.log_error(f"Missing {tk.name}", self.error_to_token(e))
137
+ return feed_current_token(iparser, e.token)
138
+
139
+ # Ignore unexpected tokens and continue parsing till we reach a known state.
140
+ self.log_error(
141
+ f"Unexpected token '{e.token.value}'", self.error_to_token(e)
142
+ )
143
+ return True
144
+
85
145
  return False
86
146
 
147
+ def error_to_message(self, e: jl.UnexpectedInput) -> str:
148
+ """Return an error message based on the exception."""
149
+ # TODO: Match more specific errors with lark's example based matching.
150
+ # Reference: https://github.com/lark-parser/lark/blob/master/examples/advanced/error_reporting_lalr.py
151
+ # e.match_examples()
152
+ if isinstance(e, jl.UnexpectedToken):
153
+ return f"Unexpected token '{e.token.value}'"
154
+ return "Syntax Error"
155
+
156
+ def error_to_token(self, e: jl.UnexpectedInput) -> uni.Token:
157
+ """Convert error to token."""
158
+ catch_error = uni.EmptyToken()
159
+ catch_error.orig_src = self.ir_in
160
+ catch_error.line_no = e.line
161
+ catch_error.end_line = e.line
162
+ catch_error.c_start = e.column
163
+ catch_error.pos_start = e.pos_in_stream or 0
164
+ if isinstance(e, jl.UnexpectedToken) and e.token:
165
+ catch_error.c_end = e.token.end_column or (e.column + 1)
166
+ catch_error.pos_end = e.token.end_pos or (catch_error.pos_start + 1)
167
+ else:
168
+ catch_error.c_end = e.column + 1
169
+ catch_error.pos_end = catch_error.pos_start + 1
170
+ return catch_error
171
+
172
+ def report_errors(self, *, colors: bool = True) -> None:
173
+ """Report errors to the user."""
174
+ # TODO: Write a better IO system.
175
+ # NOTE: Currently it writes all the errors to stderr cause LSP JsonRPC uses stdout for IPC.
176
+ if not sys.stderr.isatty():
177
+ # FIXME: If we're outputting to a file (pipe, redirection, etc) other
178
+ # than a terminal we disable colors however we should be able to force
179
+ # colors with a configuration.
180
+ colors = False
181
+ for alrt in self.errors_had:
182
+ error_label = (
183
+ "Error:" if not colors else f"{ANSIColors.RED}Error:{ANSIColors.END}"
184
+ )
185
+ print(error_label, end=" ", file=sys.stderr)
186
+ print(alrt.pretty_print(colors=colors), file=sys.stderr)
187
+
87
188
  @staticmethod
88
189
  def _comment_callback(comment: jl.Token) -> None:
89
190
  JacParser.comment_cache.append(comment)
@@ -848,25 +949,104 @@ class JacParser(Transform[uni.Source, uni.Module]):
848
949
  if self.match_token(Tok.RETURN_HINT):
849
950
  return_spec = self.consume(uni.Expr)
850
951
  return uni.FuncSignature(
952
+ posonly_params=[],
851
953
  params=[],
954
+ varargs=None,
955
+ kwonlyargs=[],
956
+ kwargs=None,
852
957
  return_type=return_spec,
853
958
  kid=self.flat_cur_nodes,
854
959
  )
855
960
  # Otherwise, parse the traditional parameter list form
856
961
  else:
857
962
  self.consume_token(Tok.LPAREN)
858
- params = self.match(list)
963
+ all_params = self.match(list) or []
964
+ posonly_params, params, varargs, kwonlyargs, kwargs = (
965
+ self._parse_parameter_categories(all_params)
966
+ )
859
967
  self.consume_token(Tok.RPAREN)
860
968
  if self.match_token(Tok.RETURN_HINT):
861
969
  return_spec = self.consume(uni.Expr)
862
970
  return uni.FuncSignature(
863
- params=(
864
- self.extract_from_list(params, uni.ParamVar) if params else []
865
- ),
971
+ posonly_params=posonly_params,
972
+ params=params,
973
+ varargs=varargs,
974
+ kwonlyargs=kwonlyargs,
975
+ kwargs=kwargs,
866
976
  return_type=return_spec,
867
977
  kid=self.flat_cur_nodes,
868
978
  )
869
979
 
980
+ def _parse_parameter_categories(self, all_params: list[uni.UniNode]) -> tuple[
981
+ list[uni.ParamVar],
982
+ list[uni.ParamVar],
983
+ uni.ParamVar | None,
984
+ list[uni.ParamVar],
985
+ uni.ParamVar | None,
986
+ ]:
987
+ posonly_params = []
988
+ params = []
989
+ varargs = None
990
+ kwonlyargs = []
991
+ kwargs = None
992
+
993
+ # Initial state determination
994
+ cur_state = "positional"
995
+ for param in all_params:
996
+ if isinstance(param, uni.Token) and param.name == Tok.DIV:
997
+ cur_state = "posonly"
998
+ break
999
+
1000
+ for cur_nd in all_params:
1001
+ cur_state = self._update_parameter_state(cur_nd, cur_state)
1002
+ if isinstance(cur_nd, uni.ParamVar):
1003
+ if cur_state == "positional":
1004
+ cur_nd.param_kind = uni.ParamKind.NORMAL
1005
+ params.append(cur_nd)
1006
+ elif cur_state == "posonly":
1007
+ cur_nd.param_kind = uni.ParamKind.POSONLY
1008
+ posonly_params.append(cur_nd)
1009
+ elif cur_state == "varargs":
1010
+ cur_nd.param_kind = uni.ParamKind.VARARG
1011
+ varargs = cur_nd
1012
+ cur_state = "keyword_only"
1013
+ elif cur_state == "keyword_only":
1014
+ cur_nd.param_kind = uni.ParamKind.KWONLY
1015
+ kwonlyargs.append(cur_nd)
1016
+ elif cur_state == "kwargs":
1017
+ cur_nd.param_kind = uni.ParamKind.KWARG
1018
+ kwargs = cur_nd
1019
+ else:
1020
+ raise self.ice()
1021
+
1022
+ return posonly_params, params, varargs, kwonlyargs, kwargs
1023
+
1024
+ def _update_parameter_state(self, cur_nd: uni.UniNode, cur_state: str) -> str:
1025
+ if isinstance(cur_nd, uni.Token):
1026
+ if cur_nd.name == Tok.DIV:
1027
+ if cur_state in ["keyword_only", "kwargs", "positional"]:
1028
+ self.parse_ref.log_error(
1029
+ "Invalid syntax in function parameters: '/' cannot appear after '*' or '**'.",
1030
+ node_override=cur_nd,
1031
+ )
1032
+ return "positional"
1033
+ elif cur_nd.name == Tok.STAR_MUL:
1034
+ if cur_state in ["keyword_only", "kwargs"]:
1035
+ self.parse_ref.log_error(
1036
+ "Invalid syntax in function parameters: '*' cannot appear after '**'.",
1037
+ node_override=cur_nd,
1038
+ )
1039
+ return "keyword_only"
1040
+ elif cur_nd.name == Tok.COMMA:
1041
+ return cur_state
1042
+
1043
+ elif isinstance(cur_nd, uni.ParamVar):
1044
+ if cur_nd.is_vararg:
1045
+ return "varargs"
1046
+ if cur_nd.is_kwargs:
1047
+ return "kwargs"
1048
+ return cur_state
1049
+
870
1050
  def func_decl_params(self, _: None) -> list[uni.UniNode]:
871
1051
  """Grammar rule.
872
1052
 
@@ -874,11 +1054,17 @@ class JacParser(Transform[uni.Source, uni.Module]):
874
1054
  """
875
1055
  return self.cur_nodes
876
1056
 
877
- def param_var(self, _: None) -> uni.ParamVar:
1057
+ def param_var(self, _: None) -> uni.ParamVar | uni.Token:
878
1058
  """Grammar rule.
879
1059
 
880
- param_var: (STAR_POW | STAR_MUL)? NAME type_tag (EQ expression)?
1060
+ param_var: (STAR_POW | STAR_MUL)? named_ref type_tag (EQ expression)?
1061
+ | DIV
1062
+ | STAR_MUL
881
1063
  """
1064
+ if len(self.cur_nodes) == 1 and (
1065
+ star_only := self.match_token(Tok.DIV) or self.match_token(Tok.STAR_MUL)
1066
+ ):
1067
+ return star_only
882
1068
  star = self.match_token(Tok.STAR_POW) or self.match_token(Tok.STAR_MUL)
883
1069
  name = self.consume(uni.Name)
884
1070
  type_tag = self.consume(uni.SubTag)
@@ -1028,7 +1214,6 @@ class JacParser(Transform[uni.Source, uni.Module]):
1028
1214
  | (yield_expr | KW_YIELD) SEMI
1029
1215
  | raise_stmt SEMI
1030
1216
  | assert_stmt SEMI
1031
- | check_stmt SEMI
1032
1217
  | assignment SEMI
1033
1218
  | delete_stmt SEMI
1034
1219
  | report_stmt SEMI
@@ -1307,18 +1492,6 @@ class JacParser(Transform[uni.Source, uni.Module]):
1307
1492
  kid=self.cur_nodes,
1308
1493
  )
1309
1494
 
1310
- def check_stmt(self, _: None) -> uni.CheckStmt:
1311
- """Grammar rule.
1312
-
1313
- check_stmt: KW_CHECK expression
1314
- """
1315
- self.consume_token(Tok.KW_CHECK)
1316
- target = self.consume(uni.Expr)
1317
- return uni.CheckStmt(
1318
- target=target,
1319
- kid=self.cur_nodes,
1320
- )
1321
-
1322
1495
  def ctrl_stmt(self, _: None) -> uni.CtrlStmt | uni.DisengageStmt:
1323
1496
  """Grammar rule.
1324
1497
 
@@ -1545,9 +1718,13 @@ class JacParser(Transform[uni.Source, uni.Module]):
1545
1718
  sig_kid.append(return_type)
1546
1719
  signature = (
1547
1720
  uni.FuncSignature(
1721
+ posonly_params=[],
1548
1722
  params=(
1549
1723
  self.extract_from_list(params, uni.ParamVar) if params else []
1550
1724
  ),
1725
+ varargs=None,
1726
+ kwonlyargs=[],
1727
+ kwargs=None,
1551
1728
  return_type=return_type,
1552
1729
  kid=sig_kid,
1553
1730
  )
@@ -2328,6 +2505,7 @@ class JacParser(Transform[uni.Source, uni.Module]):
2328
2505
  LSQUARE (KW_NODE| KW_EDGE)? expression? (edge_op_ref (filter_compr | expression)?)+ RSQUARE
2329
2506
  """
2330
2507
  self.consume_token(Tok.LSQUARE)
2508
+ is_async = bool(self.match_token(Tok.KW_ASYNC))
2331
2509
  edges_only = bool(self.match_token(Tok.KW_EDGE))
2332
2510
  self.match_token(Tok.KW_NODE)
2333
2511
  valid_chain = []
@@ -2338,6 +2516,7 @@ class JacParser(Transform[uni.Source, uni.Module]):
2338
2516
  return uni.EdgeRefTrailer(
2339
2517
  chain=valid_chain,
2340
2518
  edges_only=edges_only,
2519
+ is_async=is_async,
2341
2520
  kid=self.cur_nodes,
2342
2521
  )
2343
2522
 
@@ -4,7 +4,6 @@ from ..transform import Alert, Transform # noqa: I100
4
4
  from .annex_pass import JacAnnexPass # noqa: I100
5
5
  from .binder_pass import BinderPass # noqa: I100
6
6
  from .sym_tab_build_pass import SymTabBuildPass, UniPass # noqa: I100
7
- from .sym_tab_link_pass import SymTabLinkPass # noqa: I100
8
7
  from .def_use_pass import DefUsePass # noqa: I100
9
8
  from .sem_def_match_pass import SemDefMatchPass # noqa: I100
10
9
  from .import_pass import JacImportDepsPass # noqa: I100
@@ -12,10 +11,10 @@ from .def_impl_match_pass import DeclImplMatchPass # noqa: I100
12
11
  from .type_checker_pass import TypeCheckPass # noqa: I100
13
12
  from .pyast_load_pass import PyastBuildPass # type: ignore # noqa: I100
14
13
  from .pyast_gen_pass import PyastGenPass # noqa: I100
14
+ from .predynamo_pass import PreDynamoPass # noqa: I100
15
15
  from .pybc_gen_pass import PyBytecodeGenPass # noqa: I100
16
16
  from .cfg_build_pass import CFGBuildPass # noqa: I100
17
17
  from .pyjac_ast_link_pass import PyJacAstLinkPass # noqa: I100
18
- from .inheritance_pass import InheritancePass # noqa: I100
19
18
 
20
19
 
21
20
  __all__ = [
@@ -28,14 +27,13 @@ __all__ = [
28
27
  "BinderPass",
29
28
  "TypeCheckPass",
30
29
  "SymTabBuildPass",
31
- "SymTabLinkPass",
32
30
  "DeclImplMatchPass",
33
31
  "DefUsePass",
34
32
  "SemDefMatchPass",
35
33
  "PyastBuildPass",
36
34
  "PyastGenPass",
35
+ "PreDynamoPass",
37
36
  "PyBytecodeGenPass",
38
37
  "CFGBuildPass",
39
38
  "PyJacAstLinkPass",
40
- "InheritancePass",
41
39
  ]
@@ -30,7 +30,6 @@ class DefUsePass(UniPass):
30
30
  """Jac Ast build pass."""
31
31
 
32
32
  def enter_archetype(self, node: uni.Archetype) -> None:
33
- node.sym_tab.inherit_baseclasses_sym(node)
34
33
 
35
34
  def inform_from_walker(node: uni.UniNode) -> None:
36
35
  for i in (
@@ -47,9 +46,6 @@ class DefUsePass(UniPass):
47
46
  if isinstance(i.body, uni.ImplDef):
48
47
  inform_from_walker(i.body)
49
48
 
50
- def enter_enum(self, node: uni.Enum) -> None:
51
- node.sym_tab.inherit_baseclasses_sym(node)
52
-
53
49
  def enter_type_ref(self, node: uni.TypeRef) -> None:
54
50
  node.sym_tab.use_lookup(node)
55
51
 
@@ -0,0 +1,221 @@
1
+ """Pytorch Fix Pass."""
2
+
3
+ import ast as ast3
4
+ from typing import Optional, TypeVar, cast
5
+
6
+ import jaclang.compiler.unitree as uni
7
+ from jaclang.compiler.constant import Tokens as Tok
8
+ from jaclang.compiler.passes import UniPass
9
+
10
+
11
+ T = TypeVar("T", bound=ast3.AST)
12
+
13
+
14
+ class PreDynamoPass(UniPass):
15
+ """Pre-Dynamo pass for PyTorch."""
16
+
17
+ def enter_node(self, node: uni.UniNode) -> None:
18
+ """Enter node."""
19
+ super().enter_node(node)
20
+
21
+ def exit_node(self, node: uni.UniNode) -> None:
22
+ """Exit node."""
23
+ super().exit_node(node)
24
+
25
+ def gen_name(self, node: uni.UniNode, name: Tok, value: str) -> uni.Name:
26
+ """Generate Name."""
27
+ return uni.Name(
28
+ name=name,
29
+ value=value,
30
+ orig_src=node.loc.orig_src,
31
+ col_start=node.loc.col_start,
32
+ col_end=0,
33
+ line=node.loc.first_line,
34
+ end_line=node.loc.last_line,
35
+ pos_start=0,
36
+ pos_end=0,
37
+ )
38
+
39
+ def replace_node(
40
+ self,
41
+ new_nodes: list[uni.UniNode] | uni.UniNode,
42
+ old_node: uni.UniNode,
43
+ attr: str,
44
+ ) -> None:
45
+ """Replace old node with new nodes in parent's body and kid lists."""
46
+ parent = old_node.parent
47
+ if isinstance(new_nodes, uni.UniNode):
48
+ new_nodes.parent = parent
49
+ if hasattr(parent, attr):
50
+ lst = getattr(parent, attr)
51
+ if old_node in lst:
52
+ idx = lst.index(old_node)
53
+ lst[idx] = new_nodes
54
+ if hasattr(parent, "kid") and old_node in parent.kid:
55
+ idx = parent.kid.index(old_node)
56
+ parent.kid[idx] = new_nodes
57
+ else: # list of nodes
58
+ for n in new_nodes:
59
+ n.parent = parent
60
+ if hasattr(parent, attr):
61
+ lst = getattr(parent, attr)
62
+ if old_node in lst:
63
+ idx = lst.index(old_node)
64
+ setattr(parent, attr, lst[:idx] + new_nodes + lst[idx + 1 :])
65
+ if hasattr(parent, "kid") and old_node in parent.kid:
66
+ idx = parent.kid.index(old_node)
67
+ parent.kid = parent.kid[:idx] + new_nodes + parent.kid[idx + 1 :]
68
+
69
+ def check_same_lhs(
70
+ self, assign_a: uni.UniNode, assign_b: uni.UniNode
71
+ ) -> Optional[uni.Name]:
72
+ """Return the common LHS target if both are simple assignment with same target."""
73
+ if not (
74
+ isinstance(assign_a, uni.Assignment)
75
+ and isinstance(assign_b, uni.Assignment)
76
+ ):
77
+ return None
78
+ ta, tb = assign_a.target[0], assign_b.target[0]
79
+ if not (isinstance(ta, uni.Name) and isinstance(tb, uni.Name)):
80
+ return None
81
+ if ta.value != tb.value:
82
+ return None
83
+ return ta # common target
84
+
85
+ def check_call(self, node: uni.ExprStmt) -> Optional[tuple]:
86
+ """Return (target, name, tensor_expr, kwargs) if node is target(name, tensor_expr, **kwargs)."""
87
+ if isinstance(node, uni.ExprStmt) and isinstance(node.expr, uni.FuncCall):
88
+ call = node.expr
89
+ if (
90
+ isinstance(call.target, uni.AtomTrailer)
91
+ and len(call.params) >= 2
92
+ and isinstance(call.params[0], (uni.String, uni.MultiString))
93
+ and isinstance(call.params[1], uni.Expr)
94
+ ):
95
+ name = (
96
+ call.params[0]
97
+ if isinstance(call.params[0], uni.String)
98
+ else call.params[0].strings[0]
99
+ )
100
+ tensor_expr = call.params[1]
101
+ kwargs = (
102
+ {
103
+ kw.key._sym_name: kw.value
104
+ for kw in call.params[2:]
105
+ if isinstance(kw, uni.KWPair)
106
+ }
107
+ if len(call.params) > 2
108
+ else {}
109
+ )
110
+ return (call.target, name, tensor_expr, kwargs)
111
+ return None
112
+
113
+ def exit_if_stmt(self, node: uni.IfStmt) -> None:
114
+ """Exit if statement."""
115
+ a0 = node.body[0]
116
+ new_node = None
117
+ if node.else_body:
118
+ b0 = node.else_body.body[0]
119
+ else:
120
+ return
121
+ if isinstance(a0, uni.Assignment) and isinstance(b0, uni.Assignment):
122
+ lhs = self.check_same_lhs(a0, b0)
123
+ if lhs is not None:
124
+ func_name = self.gen_name(node, Tok.NAME, "torch")
125
+ attr_name = self.gen_name(node, Tok.NAME, "where")
126
+ target = uni.AtomTrailer(
127
+ target=func_name,
128
+ right=attr_name,
129
+ is_attr=True,
130
+ is_null_ok=False,
131
+ kid=[func_name, attr_name],
132
+ )
133
+ call = uni.FuncCall(
134
+ target=target,
135
+ params=[
136
+ node.condition,
137
+ cast(uni.Expr, a0.value),
138
+ cast(uni.Expr, b0.value),
139
+ ],
140
+ genai_call=None,
141
+ kid=[target, node.condition, a0, b0],
142
+ )
143
+ new_node = uni.Assignment(
144
+ target=[lhs], value=call, type_tag=None, kid=[lhs, call]
145
+ )
146
+ self.replace_node(new_node, node, "body")
147
+
148
+ elif isinstance(a0, uni.ReturnStmt) and isinstance(b0, uni.ReturnStmt):
149
+ aexpr, bexpr = a0.expr, b0.expr
150
+ if aexpr is None or bexpr is None:
151
+ return
152
+ func_name = self.gen_name(node, Tok.NAME, "torch")
153
+ attr_name = self.gen_name(node, Tok.NAME, "where")
154
+ target = uni.AtomTrailer(
155
+ target=func_name,
156
+ right=attr_name,
157
+ is_attr=True,
158
+ is_null_ok=False,
159
+ kid=[func_name, attr_name],
160
+ )
161
+ call = uni.FuncCall(
162
+ target=target,
163
+ params=[node.condition, cast(uni.Expr, aexpr), cast(uni.Expr, bexpr)],
164
+ genai_call=None,
165
+ kid=[target, node.condition, a0, b0],
166
+ )
167
+ new_node = uni.ReturnStmt(expr=call, kid=[call])
168
+ self.replace_node(new_node, node, "body")
169
+
170
+ elif isinstance(a0, uni.ExprStmt) and isinstance(b0, uni.ExprStmt):
171
+ a_reg = self.check_call(a0)
172
+ b_reg = self.check_call(b0)
173
+ if a_reg is not None and b_reg is not None:
174
+ a_target, a_name, a_expr, a_kwargs = a_reg
175
+ b_target, b_name, b_expr, b_kwargs = b_reg
176
+ if a_name.value == b_name.value and set(a_kwargs.keys()) == set(
177
+ b_kwargs.keys()
178
+ ):
179
+ tmp_name = self.gen_name(node, Tok.NAME, f"__{eval(a_name.value)}")
180
+ tmp_name.py_ctx_func = ast3.Store
181
+ func_name = self.gen_name(node, Tok.NAME, "torch")
182
+ attr_name = self.gen_name(node, Tok.NAME, "where")
183
+ target = uni.AtomTrailer(
184
+ target=func_name,
185
+ right=attr_name,
186
+ is_attr=True,
187
+ is_null_ok=False,
188
+ kid=[func_name, attr_name],
189
+ )
190
+ call = uni.FuncCall(
191
+ target=target,
192
+ params=[node.condition, a_expr, b_expr],
193
+ genai_call=None,
194
+ kid=[target, node.condition, a_expr, b_expr],
195
+ )
196
+ assign_node = uni.Assignment(
197
+ target=[tmp_name],
198
+ value=call,
199
+ type_tag=None,
200
+ kid=[tmp_name, call],
201
+ )
202
+
203
+ kwargs_nodes = [
204
+ uni.KWPair(
205
+ name := self.gen_name(node, Tok.NAME, k), v, [name, v]
206
+ )
207
+ for k, v in a_kwargs.items()
208
+ ]
209
+ param_name = self.gen_name(
210
+ node, Tok.NAME, f"__{eval(a_name.value)}"
211
+ )
212
+ reg_call = uni.FuncCall(
213
+ target=a_target,
214
+ params=[a_name, param_name] + kwargs_nodes,
215
+ genai_call=None,
216
+ kid=[a_target, a_name, param_name] + kwargs_nodes,
217
+ )
218
+ reg_node = uni.ExprStmt(
219
+ expr=reg_call, in_fstring=False, kid=[reg_call]
220
+ )
221
+ self.replace_node([assign_node, reg_node], node, "body")