jaclang 0.8.7__py3-none-any.whl → 0.8.9__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of jaclang might be problematic. Click here for more details.

Files changed (99) hide show
  1. jaclang/cli/cli.py +77 -29
  2. jaclang/cli/cmdreg.py +44 -0
  3. jaclang/compiler/constant.py +6 -2
  4. jaclang/compiler/jac.lark +37 -47
  5. jaclang/compiler/larkparse/jac_parser.py +2 -2
  6. jaclang/compiler/parser.py +356 -61
  7. jaclang/compiler/passes/main/__init__.py +2 -4
  8. jaclang/compiler/passes/main/def_use_pass.py +1 -4
  9. jaclang/compiler/passes/main/predynamo_pass.py +221 -0
  10. jaclang/compiler/passes/main/pyast_gen_pass.py +221 -135
  11. jaclang/compiler/passes/main/pyast_load_pass.py +54 -20
  12. jaclang/compiler/passes/main/sym_tab_build_pass.py +1 -1
  13. jaclang/compiler/passes/main/tests/fixtures/checker/import_sym.jac +2 -0
  14. jaclang/compiler/passes/main/tests/fixtures/checker/import_sym_test.jac +6 -0
  15. jaclang/compiler/passes/main/tests/fixtures/checker/imported_sym.jac +5 -0
  16. jaclang/compiler/passes/main/tests/fixtures/checker_arg_param_match.jac +37 -0
  17. jaclang/compiler/passes/main/tests/fixtures/checker_arity.jac +18 -0
  18. jaclang/compiler/passes/main/tests/fixtures/checker_cat_is_animal.jac +18 -0
  19. jaclang/compiler/passes/main/tests/fixtures/checker_float.jac +7 -0
  20. jaclang/compiler/passes/main/tests/fixtures/checker_param_types.jac +11 -0
  21. jaclang/compiler/passes/main/tests/fixtures/checker_self_type.jac +9 -0
  22. jaclang/compiler/passes/main/tests/fixtures/checker_sym_inherit.jac +42 -0
  23. jaclang/compiler/passes/main/tests/fixtures/predynamo_fix3.jac +43 -0
  24. jaclang/compiler/passes/main/tests/fixtures/predynamo_where_assign.jac +13 -0
  25. jaclang/compiler/passes/main/tests/fixtures/predynamo_where_return.jac +11 -0
  26. jaclang/compiler/passes/main/tests/test_checker_pass.py +190 -0
  27. jaclang/compiler/passes/main/tests/test_predynamo_pass.py +56 -0
  28. jaclang/compiler/passes/main/type_checker_pass.py +29 -73
  29. jaclang/compiler/passes/tool/doc_ir_gen_pass.py +302 -58
  30. jaclang/compiler/passes/tool/jac_formatter_pass.py +119 -69
  31. jaclang/compiler/passes/tool/tests/fixtures/corelib_fmt.jac +3 -3
  32. jaclang/compiler/passes/tool/tests/fixtures/general_format_checks/triple_quoted_string.jac +4 -5
  33. jaclang/compiler/passes/tool/tests/fixtures/import_fmt.jac +7 -1
  34. jaclang/compiler/passes/tool/tests/fixtures/tagbreak.jac +276 -10
  35. jaclang/compiler/passes/transform.py +12 -8
  36. jaclang/compiler/program.py +19 -7
  37. jaclang/compiler/tests/fixtures/jac_import_py_files.py +4 -0
  38. jaclang/compiler/tests/fixtures/jac_module.jac +3 -0
  39. jaclang/compiler/tests/fixtures/multiple_syntax_errors.jac +10 -0
  40. jaclang/compiler/tests/fixtures/python_module.py +1 -0
  41. jaclang/compiler/tests/test_importer.py +39 -0
  42. jaclang/compiler/tests/test_parser.py +49 -0
  43. jaclang/compiler/type_system/type_evaluator.jac +959 -0
  44. jaclang/compiler/type_system/type_utils.py +246 -0
  45. jaclang/compiler/type_system/types.py +58 -2
  46. jaclang/compiler/unitree.py +102 -107
  47. jaclang/langserve/engine.jac +138 -159
  48. jaclang/langserve/server.jac +25 -1
  49. jaclang/langserve/tests/fixtures/circle.jac +3 -3
  50. jaclang/langserve/tests/fixtures/circle_err.jac +3 -3
  51. jaclang/langserve/tests/fixtures/circle_pure.test.jac +3 -3
  52. jaclang/langserve/tests/fixtures/completion_test_err.jac +10 -0
  53. jaclang/langserve/tests/server_test/circle_template.jac +80 -0
  54. jaclang/langserve/tests/server_test/glob_template.jac +4 -0
  55. jaclang/langserve/tests/server_test/test_lang_serve.py +154 -309
  56. jaclang/langserve/tests/server_test/utils.py +153 -116
  57. jaclang/langserve/tests/test_server.py +21 -84
  58. jaclang/langserve/utils.jac +12 -15
  59. jaclang/lib.py +17 -0
  60. jaclang/runtimelib/archetype.py +25 -25
  61. jaclang/runtimelib/constructs.py +2 -2
  62. jaclang/runtimelib/machine.py +63 -46
  63. jaclang/runtimelib/meta_importer.py +27 -1
  64. jaclang/runtimelib/tests/fixtures/custom_access_validation.jac +1 -1
  65. jaclang/runtimelib/tests/fixtures/savable_object.jac +2 -2
  66. jaclang/settings.py +19 -16
  67. jaclang/tests/fixtures/abc_check.jac +3 -3
  68. jaclang/tests/fixtures/arch_rel_import_creation.jac +12 -12
  69. jaclang/tests/fixtures/attr_pattern_case.jac +18 -0
  70. jaclang/tests/fixtures/chandra_bugs2.jac +3 -3
  71. jaclang/tests/fixtures/create_dynamic_archetype.jac +13 -13
  72. jaclang/tests/fixtures/funccall_genexpr.jac +7 -0
  73. jaclang/tests/fixtures/funccall_genexpr.py +5 -0
  74. jaclang/tests/fixtures/maxfail_run_test.jac +4 -4
  75. jaclang/tests/fixtures/params/param_syntax_err.jac +9 -0
  76. jaclang/tests/fixtures/params/test_complex_params.jac +42 -0
  77. jaclang/tests/fixtures/params/test_failing_kwonly.jac +207 -0
  78. jaclang/tests/fixtures/params/test_failing_posonly.jac +116 -0
  79. jaclang/tests/fixtures/params/test_failing_varargs.jac +300 -0
  80. jaclang/tests/fixtures/params/test_kwonly_params.jac +29 -0
  81. jaclang/tests/fixtures/py2jac_params.py +8 -0
  82. jaclang/tests/fixtures/run_test.jac +4 -4
  83. jaclang/tests/test_cli.py +159 -7
  84. jaclang/tests/test_language.py +213 -38
  85. jaclang/tests/test_reference.py +3 -1
  86. jaclang/utils/helpers.py +67 -6
  87. jaclang/utils/module_resolver.py +10 -0
  88. jaclang/utils/test.py +8 -0
  89. jaclang/utils/tests/test_lang_tools.py +4 -15
  90. jaclang/utils/treeprinter.py +0 -18
  91. {jaclang-0.8.7.dist-info → jaclang-0.8.9.dist-info}/METADATA +1 -2
  92. {jaclang-0.8.7.dist-info → jaclang-0.8.9.dist-info}/RECORD +95 -65
  93. {jaclang-0.8.7.dist-info → jaclang-0.8.9.dist-info}/WHEEL +1 -1
  94. jaclang/compiler/passes/main/inheritance_pass.py +0 -131
  95. jaclang/compiler/type_system/type_evaluator.py +0 -560
  96. jaclang/langserve/dev_engine.jac +0 -645
  97. jaclang/langserve/dev_server.jac +0 -201
  98. /jaclang/{langserve/tests/server_test/code_test.py → tests/fixtures/py2jac_empty.py} +0 -0
  99. {jaclang-0.8.7.dist-info → jaclang-0.8.9.dist-info}/entry_points.txt +0 -0
@@ -16,7 +16,7 @@ from __future__ import annotations
16
16
 
17
17
  import ast as py_ast
18
18
  import os
19
- from typing import Optional, Sequence, TYPE_CHECKING, TypeAlias, TypeVar
19
+ from typing import Optional, Sequence, TYPE_CHECKING, TypeAlias, TypeVar, cast
20
20
 
21
21
  import jaclang.compiler.unitree as uni
22
22
  from jaclang.compiler.constant import Tokens as Tok
@@ -96,6 +96,8 @@ class PyastBuildPass(Transform[uni.PythonModuleAst, uni.Module]):
96
96
  body: list[stmt]
97
97
  type_ignores: list[TypeIgnore]
98
98
  """
99
+ if not node.body:
100
+ return uni.Module.make_stub(inject_src=self.ir_in)
99
101
  elements: list[uni.UniNode] = [self.convert(i) for i in node.body]
100
102
  elements[0] = (
101
103
  elements[0].expr
@@ -2091,6 +2093,7 @@ class PyastBuildPass(Transform[uni.PythonModuleAst, uni.Module]):
2091
2093
  """Process python node.
2092
2094
 
2093
2095
  class arguments(AST):
2096
+ posonlyargs: list[arg]
2094
2097
  args: list[arg]
2095
2098
  vararg: arg | None
2096
2099
  kwonlyargs: list[arg]
@@ -2098,9 +2101,23 @@ class PyastBuildPass(Transform[uni.PythonModuleAst, uni.Module]):
2098
2101
  kwarg: arg | None
2099
2102
  defaults: list[expr]
2100
2103
  """
2101
- args = [self.convert(arg) for arg in node.args]
2104
+
2105
+ def _apply_kind(params: list, kind: uni.ParamKind) -> list:
2106
+ for param in params:
2107
+ cast(uni.ParamVar, param).param_kind = kind
2108
+ return params
2109
+
2110
+ posonlyargs = _apply_kind(
2111
+ [self.convert(arg) for arg in node.posonlyargs], uni.ParamKind.POSONLY
2112
+ )
2113
+ args = _apply_kind(
2114
+ [self.convert(arg) for arg in node.args], uni.ParamKind.NORMAL
2115
+ )
2116
+
2102
2117
  vararg = self.convert(node.vararg) if node.vararg else None
2118
+
2103
2119
  if vararg and isinstance(vararg, uni.ParamVar):
2120
+ vararg.param_kind = uni.ParamKind.VARARG
2104
2121
  vararg.unpack = uni.Token(
2105
2122
  orig_src=self.orig_src,
2106
2123
  name=Tok.STAR_MUL,
@@ -2113,7 +2130,10 @@ class PyastBuildPass(Transform[uni.PythonModuleAst, uni.Module]):
2113
2130
  pos_end=0,
2114
2131
  )
2115
2132
  vararg.add_kids_left([vararg.unpack])
2116
- kwonlyargs = [self.convert(arg) for arg in node.kwonlyargs]
2133
+
2134
+ kwonlyargs = _apply_kind(
2135
+ [self.convert(arg) for arg in node.kwonlyargs], uni.ParamKind.KWONLY
2136
+ )
2117
2137
  for i in range(len(kwonlyargs)):
2118
2138
  kwa = kwonlyargs[i]
2119
2139
  kwd = node.kw_defaults[i]
@@ -2127,6 +2147,7 @@ class PyastBuildPass(Transform[uni.PythonModuleAst, uni.Module]):
2127
2147
  kwa.add_kids_right([kwa.value])
2128
2148
  kwarg = self.convert(node.kwarg) if node.kwarg else None
2129
2149
  if kwarg and isinstance(kwarg, uni.ParamVar):
2150
+ kwarg.param_kind = uni.ParamKind.KWARG
2130
2151
  kwarg.unpack = uni.Token(
2131
2152
  orig_src=self.orig_src,
2132
2153
  name=Tok.STAR_POW,
@@ -2140,29 +2161,42 @@ class PyastBuildPass(Transform[uni.PythonModuleAst, uni.Module]):
2140
2161
  )
2141
2162
  kwarg.add_kids_left([kwarg.unpack])
2142
2163
  defaults = [self.convert(expr) for expr in node.defaults]
2143
- params = [*args]
2144
- for param, default in zip(params[::-1], defaults[::-1]):
2145
- if isinstance(default, uni.Expr) and isinstance(param, uni.ParamVar):
2146
- param.value = default
2147
- param.add_kids_right([default])
2148
- if vararg:
2149
- params.append(vararg)
2150
- params += kwonlyargs
2151
- if kwarg:
2152
- params.append(kwarg)
2153
- params += defaults
2154
-
2155
- valid_params = [param for param in params if isinstance(param, uni.ParamVar)]
2156
- if valid_params:
2157
- fs_params = valid_params
2164
+ # iterate reverse to match from the end
2165
+ for para in [*posonlyargs, *args][::-1]:
2166
+ if not defaults:
2167
+ break
2168
+ default = defaults.pop()
2169
+ if (
2170
+ default
2171
+ and isinstance(para, uni.ParamVar)
2172
+ and isinstance(default, uni.Expr)
2173
+ ):
2174
+ para.value = default
2175
+ para.add_kids_right([para.value])
2176
+
2177
+ if kwonlyargs or args or posonlyargs or vararg or kwarg:
2178
+ kids = []
2179
+ kids.extend(posonlyargs) if posonlyargs else None
2180
+ kids.extend(args) if args else None
2181
+ kids.append(vararg) if vararg else None
2182
+ kids.extend(kwonlyargs) if kwonlyargs else None
2183
+ kids.append(kwarg) if kwarg else None
2158
2184
  return uni.FuncSignature(
2159
- params=fs_params,
2185
+ posonly_params=posonlyargs,
2186
+ params=args,
2187
+ varargs=vararg,
2188
+ kwonlyargs=kwonlyargs,
2189
+ kwargs=kwarg,
2160
2190
  return_type=None,
2161
- kid=fs_params,
2191
+ kid=kids,
2162
2192
  )
2163
2193
  else:
2164
2194
  return uni.FuncSignature(
2195
+ posonly_params=posonlyargs,
2165
2196
  params=[],
2197
+ varargs=vararg,
2198
+ kwonlyargs=kwonlyargs,
2199
+ kwargs=kwarg,
2166
2200
  return_type=None,
2167
2201
  kid=[self.operator(Tok.LPAREN, "("), self.operator(Tok.RPAREN, ")")],
2168
2202
  )
@@ -75,7 +75,7 @@ class SymTabBuildPass(UniPass):
75
75
  def exit_module_path(self, node: uni.ModulePath) -> None:
76
76
  if node.alias:
77
77
  node.alias.sym_tab.def_insert(node.alias, single_decl="import")
78
- elif node.path and isinstance(node.path[0], uni.Name):
78
+ elif node.path:
79
79
  if node.parent_of_type(uni.Import) and not (
80
80
  node.parent_of_type(uni.Import).from_loc
81
81
  and node.parent_of_type(uni.Import).is_jac
@@ -0,0 +1,2 @@
1
+
2
+ import from imported_sym { foo }
@@ -0,0 +1,6 @@
1
+ import from import_sym { foo }
2
+
3
+ with entry {
4
+ a: str = foo(); # <-- Ok
5
+ b: int = foo(); # <-- Error
6
+ }
@@ -0,0 +1,5 @@
1
+
2
+
3
+ def foo() -> str {
4
+ return "foo";
5
+ }
@@ -0,0 +1,37 @@
1
+
2
+ class Foo {
3
+ def bar(self: Foo, a:int) -> None {}
4
+ static def baz(self: int, a: int) -> None {}
5
+ }
6
+
7
+ def foo(a: int, b: int, /, c:int, d:int, *args:int, e:int, f:int, **kwargs:int) -> None {}
8
+ def bar(a: int, b: int, /, c:int, d:int, e:int, f:int) -> None {}
9
+ def baz(a: int, *, b:int) -> None {}
10
+
11
+ with entry {
12
+ f = Foo();
13
+
14
+ f.bar();
15
+ f.bar(1);
16
+ f.bar(1, 2);
17
+
18
+ f.baz();
19
+ f.baz(1);
20
+ f.baz(1, 2);
21
+
22
+ foo(1, 2, 3, d=4, e=5, f=6, g=7, h=8); # c is positional and d is named
23
+ foo(1, 2, 3, 4, 5, 6, 7, 8, e=5, f=6); # marching extra with *args
24
+ foo(1, 2, d=3, e=4, f=5, c=4); # order does not matter for named
25
+
26
+ foo(1, 2, 3, d=4, e=5, g=7, h=8); # missing argument 'f'
27
+ foo(1, b=2, c=3, d=4, e=5, f=6); # b is positional only
28
+
29
+ bar(1, 2, 3, 4, 5, f=6);
30
+ bar(1, 2, 3, 4, 5, 6, 7, 8, 9); # too many args
31
+ bar(1, 2, 3, 4, 5, 6, c=3); # already matched
32
+ bar(1, 2, 3, 4, 5, 6, h=1); # h is not matched
33
+
34
+ baz(a=1, b=2);
35
+ baz(1, b=2); # a can be both positional and keyword
36
+ baz(1, 2); # 'b' can only be keyword arg
37
+ }
@@ -0,0 +1,18 @@
1
+
2
+ class Foo {
3
+ def first_is_self(self: Foo) -> None {}
4
+
5
+ def with_default_args(self: Foo, a:int, b:int=42) -> None {}
6
+ }
7
+
8
+
9
+ with entry {
10
+ f = Foo();
11
+ f.first_is_self(); # <-- Ok
12
+ f.first_is_self(f); # <-- Error
13
+
14
+ f.with_default_args(1); # <-- Ok
15
+ f.with_default_args(1, 2); # <-- Ok
16
+ f.with_default_args(1, 2, 3); # <-- Error
17
+ f.with_default_args(); # <-- Error
18
+ }
@@ -0,0 +1,18 @@
1
+
2
+ obj Animal {}
3
+ obj Cat(Animal) {}
4
+ obj Lion(Cat) {}
5
+
6
+ obj NotAnimal {}
7
+
8
+ def animal_func(a: Animal) -> None {}
9
+
10
+ with entry {
11
+ cat: Cat = Cat();
12
+ lion: Lion = Lion();
13
+ not_animal: NotAnimal = NotAnimal();
14
+
15
+ animal_func(cat); # <-- Ok
16
+ animal_func(lion); # <-- Ok
17
+ animal_func(not_animal); # <-- Error
18
+ }
@@ -0,0 +1,7 @@
1
+
2
+
3
+ with entry {
4
+ pi = 3.14; # <-- infer the type
5
+ f: float = pi; # <-- OK
6
+ s: str = pi; # <-- Error
7
+ }
@@ -0,0 +1,11 @@
1
+
2
+ node A {}
3
+ node B {}
4
+
5
+ def foo(a: A) -> None { }
6
+
7
+
8
+ with entry {
9
+ foo(A()); # <-- Ok
10
+ foo(B()); # <-- Error
11
+ }
@@ -0,0 +1,9 @@
1
+
2
+ node Foo {
3
+ has i: int = 0;
4
+
5
+ def foo() {
6
+ y: int = self.i; # <-- Ok
7
+ x: str = self.i; # <-- Error
8
+ }
9
+ }
@@ -0,0 +1,42 @@
1
+
2
+ # -----------------------------------------------------------------------------
3
+ # Simple Inheritance
4
+ # -----------------------------------------------------------------------------
5
+
6
+ node Parent {
7
+ has val: int = 42;
8
+ }
9
+
10
+ node Child (Parent) {
11
+
12
+ }
13
+
14
+
15
+ with entry {
16
+ c = Child();
17
+ c.val = 42; # <-- Ok
18
+ c.val = "str"; # <-- Error
19
+ }
20
+
21
+
22
+ # -----------------------------------------------------------------------------
23
+ # A Complex Inheritance
24
+ # -----------------------------------------------------------------------------
25
+
26
+ node Animal {
27
+ has name: str;
28
+ }
29
+
30
+ node Cat(Animal) {
31
+ }
32
+
33
+ node Lion(Cat) {
34
+ }
35
+
36
+
37
+ with entry {
38
+ l = Lion();
39
+
40
+ l.name = "Simba"; # <-- Ok
41
+ l.name = 42; # <-- Error
42
+ }
@@ -0,0 +1,43 @@
1
+ import torch;
2
+ class Cfg {
3
+ def __init__(self: Cfg, max_position_embeddings: Any = 8, has_original: Any = False) {
4
+ self.max_position_embeddings = max_position_embeddings;
5
+ if has_original {
6
+ self.original_max_position_embeddings = (max_position_embeddings // 2);
7
+ }
8
+ }
9
+ }
10
+ class ToyModel(torch.nn.Module) {
11
+ def __init__(self: ToyModel, cfg: Cfg) {
12
+ super.init();
13
+ self.config = cfg;
14
+ self.long_inv_freq = torch.tensor([10.0, 20.0, 30.0]);
15
+ self.original_inv_freq = torch.tensor([1.0, 2.0, 3.0]);
16
+ }
17
+ def rope_init_fn(self: ToyModel, cfg: Any, device: Any, seq_len: int) {
18
+ return (torch.arange(1, 4, dtype=torch.float32, device=device), None);
19
+ }
20
+ def _longrope_frequency_update(
21
+ self: ToyModel,
22
+ position_ids: Any,
23
+ device: Any = 'cpu'
24
+ ) {
25
+ seq_len = (torch.max(position_ids) + 1);
26
+ if hasattr(self.config, 'original_max_position_embeddings') {
27
+ original_max_position_embeddings = self.config.original_max_position_embeddings;
28
+ } else {
29
+ original_max_position_embeddings = self.config.max_position_embeddings;
30
+ }
31
+ if (seq_len > original_max_position_embeddings) {
32
+ self.register_buffer('inv_freq', self.long_inv_freq, persistent=False);
33
+ } else {
34
+ self.register_buffer(
35
+ 'inv_freq',
36
+ self.original_inv_freq.to(device),
37
+ persistent=False
38
+ );
39
+ }
40
+ return self.inv_freq;
41
+ }
42
+ }
43
+
@@ -0,0 +1,13 @@
1
+ import torch;
2
+
3
+ def with_breaks(a: Any, b: Any) {
4
+ x = (a / (torch.abs(a) + 1));
5
+ b = (b * 1.414);
6
+ if (b.sum() < 0) {
7
+ b = (b * -1);
8
+ } else{
9
+ b = (b * -2);
10
+ }
11
+
12
+ return (x * b);
13
+ }
@@ -0,0 +1,11 @@
1
+ import torch;
2
+
3
+ def with_breaks(a: Any, b: Any) {
4
+ x = (a / (torch.abs(a) + 1));
5
+ b = (b * 1.414);
6
+ if (b.sum() < 0) {
7
+ return x * (b * -1);
8
+ } else{
9
+ return x * (b * -2);
10
+ }
11
+ }
@@ -28,6 +28,18 @@ class TypeCheckerPassTests(TestCase):
28
28
  ^^^^^^^^^^^^^^^^^^^^^^
29
29
  """, program.errors_had[1].pretty_print())
30
30
 
31
+ def test_float_types(self) -> None:
32
+ program = JacProgram()
33
+ mod = program.compile(self.fixture_abs_path("checker_float.jac"))
34
+ TypeCheckPass(ir_in=mod, prog=program)
35
+ self.assertEqual(len(program.errors_had), 1)
36
+ self._assert_error_pretty_found("""
37
+ f: float = pi; # <-- OK
38
+ s: str = pi; # <-- Error
39
+ ^^^^^^^^^^^
40
+ """, program.errors_had[0].pretty_print())
41
+
42
+
31
43
  def test_infer_type_of_assignment(self) -> None:
32
44
  program = JacProgram()
33
45
  mod = program.compile(self.fixture_abs_path("infer_type_assignment.jac"))
@@ -49,6 +61,17 @@ class TypeCheckerPassTests(TestCase):
49
61
  ^^^^^^^^^^^^^^^^^^^
50
62
  """, program.errors_had[0].pretty_print())
51
63
 
64
+ def test_imported_sym(self) -> None:
65
+ program = JacProgram()
66
+ mod = program.compile(self.fixture_abs_path("checker/import_sym_test.jac"))
67
+ TypeCheckPass(ir_in=mod, prog=program)
68
+ self.assertEqual(len(program.errors_had), 1)
69
+ self._assert_error_pretty_found("""
70
+ a: str = foo(); # <-- Ok
71
+ b: int = foo(); # <-- Error
72
+ ^^^^^^^^^^^^^^
73
+ """, program.errors_had[0].pretty_print())
74
+
52
75
  def test_member_access_type_infered(self) -> None:
53
76
  program = JacProgram()
54
77
  mod = program.compile(self.fixture_abs_path("member_access_type_inferred.jac"))
@@ -59,6 +82,22 @@ class TypeCheckerPassTests(TestCase):
59
82
  ^^^^^^^^^
60
83
  """, program.errors_had[0].pretty_print())
61
84
 
85
+ def test_inherited_symbol(self) -> None:
86
+ program = JacProgram()
87
+ mod = program.compile(self.fixture_abs_path("checker_sym_inherit.jac"))
88
+ TypeCheckPass(ir_in=mod, prog=program)
89
+ self.assertEqual(len(program.errors_had), 2)
90
+ self._assert_error_pretty_found("""
91
+ c.val = 42; # <-- Ok
92
+ c.val = "str"; # <-- Error
93
+ ^^^^^^^^^^^^^
94
+ """, program.errors_had[0].pretty_print())
95
+ self._assert_error_pretty_found("""
96
+ l.name = "Simba"; # <-- Ok
97
+ l.name = 42; # <-- Error
98
+ ^^^^^^^^^^^
99
+ """, program.errors_had[1].pretty_print())
100
+
62
101
  def test_import_symbol_type_infer(self) -> None:
63
102
  program = JacProgram()
64
103
  mod = program.compile(self.fixture_abs_path("import_symbol_type_infer.jac"))
@@ -104,6 +143,144 @@ class TypeCheckerPassTests(TestCase):
104
143
  ^^^^^^^^^^^^^^^^
105
144
  """, program.errors_had[0].pretty_print())
106
145
 
146
+ def test_arity(self) -> None:
147
+ path = self.fixture_abs_path("checker_arity.jac")
148
+ program = JacProgram()
149
+ mod = program.compile(path)
150
+ TypeCheckPass(ir_in=mod, prog=program)
151
+ self.assertEqual(len(program.errors_had), 3)
152
+ self._assert_error_pretty_found("""
153
+ f.first_is_self(f); # <-- Error
154
+ ^
155
+ """, program.errors_had[0].pretty_print())
156
+ self._assert_error_pretty_found("""
157
+ f.with_default_args(1, 2, 3); # <-- Error
158
+ ^
159
+ """, program.errors_had[1].pretty_print())
160
+ self._assert_error_pretty_found("""
161
+ f.with_default_args(); # <-- Error
162
+ ^^^^^^^^^^^^^^^^^^^^^
163
+ """, program.errors_had[2].pretty_print())
164
+
165
+ def test_param_types(self) -> None:
166
+ path = self.fixture_abs_path("checker_param_types.jac")
167
+ program = JacProgram()
168
+ mod = program.compile(path)
169
+ TypeCheckPass(ir_in=mod, prog=program)
170
+ self.assertEqual(len(program.errors_had), 1)
171
+ self._assert_error_pretty_found("""
172
+ foo(A()); # <-- Ok
173
+ foo(B()); # <-- Error
174
+ ^^^
175
+ """, program.errors_had[0].pretty_print())
176
+
177
+ def test_param_arg_match(self) -> None:
178
+ program = JacProgram()
179
+ path = self.fixture_abs_path("checker_arg_param_match.jac")
180
+ mod = program.compile(path)
181
+ TypeCheckPass(ir_in=mod, prog=program)
182
+ self.assertEqual(len(program.errors_had), 13)
183
+
184
+ expected_errors = [
185
+ """
186
+ Not all required parameters were provided in the function call: 'a'
187
+ f = Foo();
188
+ f.bar();
189
+ ^^^^^^^
190
+ """,
191
+ """
192
+ Too many positional arguments
193
+ f.bar();
194
+ f.bar(1);
195
+ f.bar(1, 2);
196
+ ^
197
+ """,
198
+ """
199
+ Not all required parameters were provided in the function call: 'self', 'a'
200
+ f.bar(1, 2);
201
+ f.baz();
202
+ ^^^^^^^
203
+ """,
204
+ """
205
+ Not all required parameters were provided in the function call: 'a'
206
+ f.baz();
207
+ f.baz(1);
208
+ ^^^^^^^^
209
+ """,
210
+ """
211
+ Not all required parameters were provided in the function call: 'f'
212
+ foo(1, 2, d=3, e=4, f=5, c=4); # order does not matter for named
213
+ foo(1, 2, 3, d=4, e=5, g=7, h=8); # missing argument 'f'
214
+ ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
215
+ """,
216
+ """
217
+ Positional only parameter 'b' cannot be matched with a named argument
218
+ foo(1, 2, 3, d=4, e=5, g=7, h=8); # missing argument 'f'
219
+ foo(1, b=2, c=3, d=4, e=5, f=6); # b is positional only
220
+ ^^^
221
+ """,
222
+ """
223
+ Too many positional arguments
224
+ bar(1, 2, 3, 4, 5, f=6);
225
+ bar(1, 2, 3, 4, 5, 6, 7, 8, 9); # too many args
226
+ ^
227
+ """,
228
+ """
229
+ Too many positional arguments
230
+ bar(1, 2, 3, 4, 5, f=6);
231
+ bar(1, 2, 3, 4, 5, 6, 7, 8, 9); # too many args
232
+ ^
233
+ """,
234
+ """
235
+ Too many positional arguments
236
+ bar(1, 2, 3, 4, 5, f=6);
237
+ bar(1, 2, 3, 4, 5, 6, 7, 8, 9); # too many args
238
+ ^
239
+ """,
240
+ """
241
+ Parameter 'c' already matched
242
+ bar(1, 2, 3, 4, 5, f=6);
243
+ bar(1, 2, 3, 4, 5, 6, 7, 8, 9); # too many args
244
+ bar(1, 2, 3, 4, 5, 6, c=3); # already matched
245
+ ^^^
246
+ """,
247
+ """
248
+ Named argument 'h' does not match any parameter
249
+ bar(1, 2, 3, 4, 5, 6, 7, 8, 9); # too many args
250
+ bar(1, 2, 3, 4, 5, 6, c=3); # already matched
251
+ bar(1, 2, 3, 4, 5, 6, h=1); # h is not matched
252
+ ^^^
253
+ """,
254
+ """
255
+ Too many positional arguments
256
+ baz(a=1, b=2);
257
+ baz(1, b=2); # a can be both positional and keyword
258
+ baz(1, 2); # 'b' can only be keyword arg
259
+ ^
260
+ """,
261
+ """
262
+ Not all required parameters were provided in the function call: 'b'
263
+ baz(a=1, b=2);
264
+ baz(1, b=2); # a can be both positional and keyword
265
+ baz(1, 2); # 'b' can only be keyword arg
266
+ ^^^^^^^^^
267
+ """,
268
+ ]
269
+
270
+ for i, expected in enumerate(expected_errors):
271
+ self._assert_error_pretty_found(expected, program.errors_had[i].pretty_print())
272
+
273
+ def test_self_type_inference(self) -> None:
274
+ path = self.fixture_abs_path("checker_self_type.jac")
275
+ program = JacProgram()
276
+ mod = program.compile(path)
277
+ TypeCheckPass(ir_in=mod, prog=program)
278
+ self.assertEqual(len(program.errors_had), 1)
279
+ self._assert_error_pretty_found("""
280
+ x: str = self.i; # <-- Error
281
+ ^^^^^^^^^^^^^^^
282
+ """, program.errors_had[0].pretty_print())
283
+
107
284
  def test_binary_op(self) -> None:
108
285
  program = JacProgram()
109
286
  mod = program.compile(self.fixture_abs_path("checker_binary_op.jac"))
@@ -140,6 +317,19 @@ class TypeCheckerPassTests(TestCase):
140
317
  ^^^^^^^^^^^^^^
141
318
  """, program.errors_had[0].pretty_print())
142
319
 
320
+ def test_checker_cat_is_animal(self) -> None:
321
+ path = self.fixture_abs_path("checker_cat_is_animal.jac")
322
+ program = JacProgram()
323
+ mod = program.compile(path)
324
+ TypeCheckPass(ir_in=mod, prog=program)
325
+ self.assertEqual(len(program.errors_had), 1)
326
+ self._assert_error_pretty_found("""
327
+ animal_func(cat); # <-- Ok
328
+ animal_func(lion); # <-- Ok
329
+ animal_func(not_animal); # <-- Error
330
+ ^^^^^^^^^^
331
+ """, program.errors_had[0].pretty_print())
332
+
143
333
  def test_checker_import_missing_module(self) -> None:
144
334
  path = self.fixture_abs_path("checker_import_missing_module.jac")
145
335
  program = JacProgram()
@@ -0,0 +1,56 @@
1
+ """Test ast build pass module."""
2
+
3
+ import io
4
+ import os
5
+ import sys
6
+
7
+ from jaclang.compiler.program import JacProgram, py_code_gen
8
+ from jaclang.utils.test import TestCase
9
+ from jaclang.settings import settings
10
+ from jaclang.compiler.passes.main import PreDynamoPass
11
+
12
+ class PreDynamoPassTests(TestCase):
13
+ """Test pass module."""
14
+
15
+ TargetPass = PreDynamoPass
16
+
17
+ def setUp(self) -> None:
18
+ """Set up test."""
19
+ settings.predynamo_pass = True
20
+ return super().setUp()
21
+
22
+ def tearDown(self) -> None:
23
+ """Tear down test."""
24
+ settings.predynamo_pass = False
25
+ # Remove PreDynamoPass from global py_code_gen list if it was added
26
+ if PreDynamoPass in py_code_gen:
27
+ py_code_gen.remove(PreDynamoPass)
28
+ return super().tearDown()
29
+
30
+ def test_predynamo_where_assign(self) -> None:
31
+ """Test torch.where transformation."""
32
+ captured_output = io.StringIO()
33
+ sys.stdout = captured_output
34
+ settings.predynamo_pass = True
35
+ code_gen = JacProgram().compile(self.fixture_abs_path("predynamo_where_assign.jac"))
36
+ sys.stdout = sys.__stdout__
37
+ self.assertIn("torch.where", code_gen.unparse())
38
+
39
+ def test_predynamo_where_return(self) -> None:
40
+ """Test torch.where transformation."""
41
+ captured_output = io.StringIO()
42
+ sys.stdout = captured_output
43
+ code_gen = JacProgram().compile(self.fixture_abs_path("predynamo_where_return.jac"))
44
+ sys.stdout = sys.__stdout__
45
+ self.assertIn("torch.where", code_gen.unparse())
46
+
47
+
48
+ def test_predynamo_fix3(self) -> None:
49
+ """Test torch.where transformation."""
50
+ captured_output = io.StringIO()
51
+ sys.stdout = captured_output
52
+ code_gen = JacProgram().compile(self.fixture_abs_path("predynamo_fix3.jac"))
53
+ sys.stdout = sys.__stdout__
54
+ unparsed_code = code_gen.unparse()
55
+ self.assertIn("__inv_freq = torch.where(", unparsed_code)
56
+ self.assertIn("self.register_buffer('inv_freq', __inv_freq, persistent=False);", unparsed_code)