jaclang 0.8.6__py3-none-any.whl → 0.8.8__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of jaclang might be problematic; the original registry page linked to further details.

Files changed (103)
  1. jaclang/cli/cli.md +3 -3
  2. jaclang/cli/cli.py +37 -37
  3. jaclang/cli/cmdreg.py +45 -140
  4. jaclang/compiler/constant.py +0 -1
  5. jaclang/compiler/jac.lark +3 -6
  6. jaclang/compiler/larkparse/jac_parser.py +2 -2
  7. jaclang/compiler/parser.py +213 -34
  8. jaclang/compiler/passes/main/__init__.py +2 -4
  9. jaclang/compiler/passes/main/def_use_pass.py +0 -4
  10. jaclang/compiler/passes/main/predynamo_pass.py +221 -0
  11. jaclang/compiler/passes/main/pyast_gen_pass.py +83 -55
  12. jaclang/compiler/passes/main/pyast_load_pass.py +66 -40
  13. jaclang/compiler/passes/main/sym_tab_build_pass.py +1 -1
  14. jaclang/compiler/passes/main/tests/fixtures/checker/import_sym.jac +2 -0
  15. jaclang/compiler/passes/main/tests/fixtures/checker/import_sym_test.jac +6 -0
  16. jaclang/compiler/passes/main/tests/fixtures/checker/imported_sym.jac +5 -0
  17. jaclang/compiler/passes/main/tests/fixtures/checker_arg_param_match.jac +37 -0
  18. jaclang/compiler/passes/main/tests/fixtures/checker_arity.jac +18 -0
  19. jaclang/compiler/passes/main/tests/fixtures/checker_binary_op.jac +21 -0
  20. jaclang/compiler/passes/main/tests/fixtures/checker_call_expr_class.jac +12 -0
  21. jaclang/compiler/passes/main/tests/fixtures/checker_cat_is_animal.jac +18 -0
  22. jaclang/compiler/passes/main/tests/fixtures/checker_cyclic_symbol.jac +4 -0
  23. jaclang/compiler/passes/main/tests/fixtures/checker_expr_call.jac +9 -0
  24. jaclang/compiler/passes/main/tests/fixtures/checker_float.jac +7 -0
  25. jaclang/compiler/passes/main/tests/fixtures/checker_import_missing_module.jac +13 -0
  26. jaclang/compiler/passes/main/tests/fixtures/checker_magic_call.jac +17 -0
  27. jaclang/compiler/passes/main/tests/fixtures/checker_mod_path.jac +8 -0
  28. jaclang/compiler/passes/main/tests/fixtures/checker_param_types.jac +11 -0
  29. jaclang/compiler/passes/main/tests/fixtures/checker_self_type.jac +9 -0
  30. jaclang/compiler/passes/main/tests/fixtures/checker_sym_inherit.jac +42 -0
  31. jaclang/compiler/passes/main/tests/fixtures/predynamo_fix3.jac +43 -0
  32. jaclang/compiler/passes/main/tests/fixtures/predynamo_where_assign.jac +13 -0
  33. jaclang/compiler/passes/main/tests/fixtures/predynamo_where_return.jac +11 -0
  34. jaclang/compiler/passes/main/tests/test_checker_pass.py +265 -0
  35. jaclang/compiler/passes/main/tests/test_predynamo_pass.py +57 -0
  36. jaclang/compiler/passes/main/type_checker_pass.py +36 -61
  37. jaclang/compiler/passes/tool/doc_ir_gen_pass.py +204 -44
  38. jaclang/compiler/passes/tool/jac_formatter_pass.py +119 -69
  39. jaclang/compiler/passes/tool/tests/fixtures/corelib_fmt.jac +3 -3
  40. jaclang/compiler/passes/tool/tests/fixtures/general_format_checks/triple_quoted_string.jac +4 -5
  41. jaclang/compiler/passes/tool/tests/fixtures/tagbreak.jac +171 -11
  42. jaclang/compiler/passes/transform.py +12 -8
  43. jaclang/compiler/program.py +14 -6
  44. jaclang/compiler/tests/fixtures/jac_import_py_files.py +4 -0
  45. jaclang/compiler/tests/fixtures/jac_module.jac +3 -0
  46. jaclang/compiler/tests/fixtures/multiple_syntax_errors.jac +10 -0
  47. jaclang/compiler/tests/fixtures/python_module.py +1 -0
  48. jaclang/compiler/tests/test_importer.py +39 -0
  49. jaclang/compiler/tests/test_parser.py +49 -0
  50. jaclang/compiler/type_system/operations.py +104 -0
  51. jaclang/compiler/type_system/type_evaluator.py +470 -47
  52. jaclang/compiler/type_system/type_utils.py +246 -0
  53. jaclang/compiler/type_system/types.py +58 -2
  54. jaclang/compiler/unitree.py +79 -94
  55. jaclang/langserve/engine.jac +253 -230
  56. jaclang/langserve/server.jac +46 -15
  57. jaclang/langserve/tests/fixtures/circle.jac +3 -3
  58. jaclang/langserve/tests/fixtures/circle_err.jac +3 -3
  59. jaclang/langserve/tests/fixtures/circle_pure.test.jac +3 -3
  60. jaclang/langserve/tests/fixtures/completion_test_err.jac +10 -0
  61. jaclang/langserve/tests/server_test/circle_template.jac +80 -0
  62. jaclang/langserve/tests/server_test/glob_template.jac +4 -0
  63. jaclang/langserve/tests/server_test/test_lang_serve.py +154 -312
  64. jaclang/langserve/tests/server_test/utils.py +153 -116
  65. jaclang/langserve/tests/test_dev_server.py +1 -1
  66. jaclang/langserve/tests/test_server.py +30 -86
  67. jaclang/langserve/utils.jac +56 -63
  68. jaclang/runtimelib/machine.py +7 -0
  69. jaclang/runtimelib/meta_importer.py +27 -1
  70. jaclang/runtimelib/tests/fixtures/custom_access_validation.jac +1 -1
  71. jaclang/runtimelib/tests/fixtures/savable_object.jac +2 -2
  72. jaclang/settings.py +18 -14
  73. jaclang/tests/fixtures/abc_check.jac +3 -3
  74. jaclang/tests/fixtures/arch_rel_import_creation.jac +12 -12
  75. jaclang/tests/fixtures/chandra_bugs2.jac +3 -3
  76. jaclang/tests/fixtures/create_dynamic_archetype.jac +13 -13
  77. jaclang/tests/fixtures/jac_run_py_bugs.py +18 -0
  78. jaclang/tests/fixtures/jac_run_py_import.py +13 -0
  79. jaclang/tests/fixtures/lambda_arg_annotation.jac +15 -0
  80. jaclang/tests/fixtures/lambda_self.jac +18 -0
  81. jaclang/tests/fixtures/maxfail_run_test.jac +4 -4
  82. jaclang/tests/fixtures/params/param_syntax_err.jac +9 -0
  83. jaclang/tests/fixtures/params/test_complex_params.jac +42 -0
  84. jaclang/tests/fixtures/params/test_failing_kwonly.jac +207 -0
  85. jaclang/tests/fixtures/params/test_failing_posonly.jac +116 -0
  86. jaclang/tests/fixtures/params/test_failing_varargs.jac +300 -0
  87. jaclang/tests/fixtures/params/test_kwonly_params.jac +29 -0
  88. jaclang/tests/fixtures/py2jac_params.py +8 -0
  89. jaclang/tests/fixtures/run_test.jac +4 -4
  90. jaclang/tests/test_cli.py +103 -18
  91. jaclang/tests/test_language.py +74 -16
  92. jaclang/utils/helpers.py +47 -2
  93. jaclang/utils/module_resolver.py +11 -1
  94. jaclang/utils/test.py +8 -0
  95. jaclang/utils/treeprinter.py +0 -18
  96. {jaclang-0.8.6.dist-info → jaclang-0.8.8.dist-info}/METADATA +3 -3
  97. {jaclang-0.8.6.dist-info → jaclang-0.8.8.dist-info}/RECORD +99 -62
  98. {jaclang-0.8.6.dist-info → jaclang-0.8.8.dist-info}/WHEEL +1 -1
  99. jaclang/compiler/passes/main/inheritance_pass.py +0 -131
  100. jaclang/langserve/dev_engine.jac +0 -645
  101. jaclang/langserve/dev_server.jac +0 -201
  102. jaclang/langserve/tests/server_test/code_test.py +0 -0
  103. {jaclang-0.8.6.dist-info → jaclang-0.8.8.dist-info}/entry_points.txt +0 -0
@@ -0,0 +1,17 @@
1
+
2
+ node Bar {}
3
+
4
+ node Foo {
5
+ def __call__() -> Bar {
6
+ return Bar();
7
+ }
8
+ }
9
+
10
+ def fn() -> Foo {
11
+ return Foo();
12
+ }
13
+
14
+ with entry{
15
+ b: Bar = fn()(); # <-- Ok
16
+ f: Foo = fn()(); # <-- Error
17
+ }
@@ -0,0 +1,8 @@
1
+
2
+
3
+ import jaclang.compiler.unitree as uni;
4
+
5
+
6
+ with entry{
7
+ a:int = uni.Module; # <-- Error
8
+ }
@@ -0,0 +1,11 @@
1
+
2
+ node A {}
3
+ node B {}
4
+
5
+ def foo(a: A) -> None { }
6
+
7
+
8
+ with entry {
9
+ foo(A()); # <-- Ok
10
+ foo(B()); # <-- Error
11
+ }
@@ -0,0 +1,9 @@
1
+
2
+ node Foo {
3
+ has i: int = 0;
4
+
5
+ def foo() {
6
+ y: int = self.i; # <-- Ok
7
+ x: str = self.i; # <-- Error
8
+ }
9
+ }
@@ -0,0 +1,42 @@
1
+
2
+ # -----------------------------------------------------------------------------
3
+ # Simple Inheritance
4
+ # -----------------------------------------------------------------------------
5
+
6
+ node Parent {
7
+ has val: int = 42;
8
+ }
9
+
10
+ node Child (Parent) {
11
+
12
+ }
13
+
14
+
15
+ with entry {
16
+ c = Child();
17
+ c.val = 42; # <-- Ok
18
+ c.val = "str"; # <-- Error
19
+ }
20
+
21
+
22
+ # -----------------------------------------------------------------------------
23
+ # A Complex Inheritance
24
+ # -----------------------------------------------------------------------------
25
+
26
+ node Animal {
27
+ has name: str;
28
+ }
29
+
30
+ node Cat(Animal) {
31
+ }
32
+
33
+ node Lion(Cat) {
34
+ }
35
+
36
+
37
+ with entry {
38
+ l = Lion();
39
+
40
+ l.name = "Simba"; # <-- Ok
41
+ l.name = 42; # <-- Error
42
+ }
@@ -0,0 +1,43 @@
1
+ import torch;
2
+ class Cfg {
3
+ def __init__(self: Cfg, max_position_embeddings: Any = 8, has_original: Any = False) {
4
+ self.max_position_embeddings = max_position_embeddings;
5
+ if has_original {
6
+ self.original_max_position_embeddings = (max_position_embeddings // 2);
7
+ }
8
+ }
9
+ }
10
+ class ToyModel(torch.nn.Module) {
11
+ def __init__(self: ToyModel, cfg: Cfg) {
12
+ super.init();
13
+ self.config = cfg;
14
+ self.long_inv_freq = torch.tensor([10.0, 20.0, 30.0]);
15
+ self.original_inv_freq = torch.tensor([1.0, 2.0, 3.0]);
16
+ }
17
+ def rope_init_fn(self: ToyModel, cfg: Any, device: Any, seq_len: int) {
18
+ return (torch.arange(1, 4, dtype=torch.float32, device=device), None);
19
+ }
20
+ def _longrope_frequency_update(
21
+ self: ToyModel,
22
+ position_ids: Any,
23
+ device: Any = 'cpu'
24
+ ) {
25
+ seq_len = (torch.max(position_ids) + 1);
26
+ if hasattr(self.config, 'original_max_position_embeddings') {
27
+ original_max_position_embeddings = self.config.original_max_position_embeddings;
28
+ } else {
29
+ original_max_position_embeddings = self.config.max_position_embeddings;
30
+ }
31
+ if (seq_len > original_max_position_embeddings) {
32
+ self.register_buffer('inv_freq', self.long_inv_freq, persistent=False);
33
+ } else {
34
+ self.register_buffer(
35
+ 'inv_freq',
36
+ self.original_inv_freq.to(device),
37
+ persistent=False
38
+ );
39
+ }
40
+ return self.inv_freq;
41
+ }
42
+ }
43
+
@@ -0,0 +1,13 @@
1
+ import torch;
2
+
3
+ def with_breaks(a: Any, b: Any) {
4
+ x = (a / (torch.abs(a) + 1));
5
+ b = (b * 1.414);
6
+ if (b.sum() < 0) {
7
+ b = (b * -1);
8
+ } else{
9
+ b = (b * -2);
10
+ }
11
+
12
+ return (x * b);
13
+ }
@@ -0,0 +1,11 @@
1
+ import torch;
2
+
3
+ def with_breaks(a: Any, b: Any) {
4
+ x = (a / (torch.abs(a) + 1));
5
+ b = (b * 1.414);
6
+ if (b.sum() < 0) {
7
+ return x * (b * -1);
8
+ } else{
9
+ return x * (b * -2);
10
+ }
11
+ }
@@ -28,6 +28,18 @@ class TypeCheckerPassTests(TestCase):
28
28
  ^^^^^^^^^^^^^^^^^^^^^^
29
29
  """, program.errors_had[1].pretty_print())
30
30
 
31
+ def test_float_types(self) -> None:
32
+ program = JacProgram()
33
+ mod = program.compile(self.fixture_abs_path("checker_float.jac"))
34
+ TypeCheckPass(ir_in=mod, prog=program)
35
+ self.assertEqual(len(program.errors_had), 1)
36
+ self._assert_error_pretty_found("""
37
+ f: float = pi; # <-- OK
38
+ s: str = pi; # <-- Error
39
+ ^^^^^^^^^^^
40
+ """, program.errors_had[0].pretty_print())
41
+
42
+
31
43
  def test_infer_type_of_assignment(self) -> None:
32
44
  program = JacProgram()
33
45
  mod = program.compile(self.fixture_abs_path("infer_type_assignment.jac"))
@@ -49,6 +61,17 @@ class TypeCheckerPassTests(TestCase):
49
61
  ^^^^^^^^^^^^^^^^^^^
50
62
  """, program.errors_had[0].pretty_print())
51
63
 
64
+ def test_imported_sym(self) -> None:
65
+ program = JacProgram()
66
+ mod = program.compile(self.fixture_abs_path("checker/import_sym_test.jac"))
67
+ TypeCheckPass(ir_in=mod, prog=program)
68
+ self.assertEqual(len(program.errors_had), 1)
69
+ self._assert_error_pretty_found("""
70
+ a: str = foo(); # <-- Ok
71
+ b: int = foo(); # <-- Error
72
+ ^^^^^^^^^^^^^^
73
+ """, program.errors_had[0].pretty_print())
74
+
52
75
  def test_member_access_type_infered(self) -> None:
53
76
  program = JacProgram()
54
77
  mod = program.compile(self.fixture_abs_path("member_access_type_inferred.jac"))
@@ -59,6 +82,22 @@ class TypeCheckerPassTests(TestCase):
59
82
  ^^^^^^^^^
60
83
  """, program.errors_had[0].pretty_print())
61
84
 
85
+ def test_inherited_symbol(self) -> None:
86
+ program = JacProgram()
87
+ mod = program.compile(self.fixture_abs_path("checker_sym_inherit.jac"))
88
+ TypeCheckPass(ir_in=mod, prog=program)
89
+ self.assertEqual(len(program.errors_had), 2)
90
+ self._assert_error_pretty_found("""
91
+ c.val = 42; # <-- Ok
92
+ c.val = "str"; # <-- Error
93
+ ^^^^^^^^^^^^^
94
+ """, program.errors_had[0].pretty_print())
95
+ self._assert_error_pretty_found("""
96
+ l.name = "Simba"; # <-- Ok
97
+ l.name = 42; # <-- Error
98
+ ^^^^^^^^^^^
99
+ """, program.errors_had[1].pretty_print())
100
+
62
101
  def test_import_symbol_type_infer(self) -> None:
63
102
  program = JacProgram()
64
103
  mod = program.compile(self.fixture_abs_path("import_symbol_type_infer.jac"))
@@ -81,6 +120,232 @@ class TypeCheckerPassTests(TestCase):
81
120
  ^^^^^^^^^^^^^^
82
121
  """, program.errors_had[0].pretty_print())
83
122
 
123
+ def test_call_expr(self) -> None:
124
+ path = self.fixture_abs_path("checker_expr_call.jac")
125
+ program = JacProgram()
126
+ mod = program.compile(path)
127
+ TypeCheckPass(ir_in=mod, prog=program)
128
+ self.assertEqual(len(program.errors_had), 1)
129
+ self._assert_error_pretty_found("""
130
+ s: str = foo();
131
+ ^^^^^^^^^^^^^^
132
+ """, program.errors_had[0].pretty_print())
133
+
134
+ def test_call_expr_magic(self) -> None:
135
+ path = self.fixture_abs_path("checker_magic_call.jac")
136
+ program = JacProgram()
137
+ mod = program.compile(path)
138
+ TypeCheckPass(ir_in=mod, prog=program)
139
+ self.assertEqual(len(program.errors_had), 1)
140
+ self._assert_error_pretty_found("""
141
+ b: Bar = fn()(); # <-- Ok
142
+ f: Foo = fn()(); # <-- Error
143
+ ^^^^^^^^^^^^^^^^
144
+ """, program.errors_had[0].pretty_print())
145
+
146
+ def test_arity(self) -> None:
147
+ path = self.fixture_abs_path("checker_arity.jac")
148
+ program = JacProgram()
149
+ mod = program.compile(path)
150
+ TypeCheckPass(ir_in=mod, prog=program)
151
+ self.assertEqual(len(program.errors_had), 3)
152
+ self._assert_error_pretty_found("""
153
+ f.first_is_self(f); # <-- Error
154
+ ^
155
+ """, program.errors_had[0].pretty_print())
156
+ self._assert_error_pretty_found("""
157
+ f.with_default_args(1, 2, 3); # <-- Error
158
+ ^
159
+ """, program.errors_had[1].pretty_print())
160
+ self._assert_error_pretty_found("""
161
+ f.with_default_args(); # <-- Error
162
+ ^^^^^^^^^^^^^^^^^^^^^
163
+ """, program.errors_had[2].pretty_print())
164
+
165
+ def test_param_types(self) -> None:
166
+ path = self.fixture_abs_path("checker_param_types.jac")
167
+ program = JacProgram()
168
+ mod = program.compile(path)
169
+ TypeCheckPass(ir_in=mod, prog=program)
170
+ self.assertEqual(len(program.errors_had), 1)
171
+ self._assert_error_pretty_found("""
172
+ foo(A()); # <-- Ok
173
+ foo(B()); # <-- Error
174
+ ^^^
175
+ """, program.errors_had[0].pretty_print())
176
+
177
+ def test_param_arg_match(self) -> None:
178
+ path = self.fixture_abs_path("checker_param_types.jac")
179
+ program = JacProgram()
180
+ path = self.fixture_abs_path("checker_arg_param_match.jac")
181
+ mod = program.compile(path)
182
+ TypeCheckPass(ir_in=mod, prog=program)
183
+ self.assertEqual(len(program.errors_had), 13)
184
+
185
+ expected_errors = [
186
+ """
187
+ Not all required parameters were provided in the function call: 'a'
188
+ f = Foo();
189
+ f.bar();
190
+ ^^^^^^^
191
+ """,
192
+ """
193
+ Too many positional arguments
194
+ f.bar();
195
+ f.bar(1);
196
+ f.bar(1, 2);
197
+ ^
198
+ """,
199
+ """
200
+ Not all required parameters were provided in the function call: 'self', 'a'
201
+ f.bar(1, 2);
202
+ f.baz();
203
+ ^^^^^^^
204
+ """,
205
+ """
206
+ Not all required parameters were provided in the function call: 'a'
207
+ f.baz();
208
+ f.baz(1);
209
+ ^^^^^^^^
210
+ """,
211
+ """
212
+ Not all required parameters were provided in the function call: 'f'
213
+ foo(1, 2, d=3, e=4, f=5, c=4); # order does not matter for named
214
+ foo(1, 2, 3, d=4, e=5, g=7, h=8); # missing argument 'f'
215
+ ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
216
+ """,
217
+ """
218
+ Positional only parameter 'b' cannot be matched with a named argument
219
+ foo(1, 2, 3, d=4, e=5, g=7, h=8); # missing argument 'f'
220
+ foo(1, b=2, c=3, d=4, e=5, f=6); # b is positional only
221
+ ^^^
222
+ """,
223
+ """
224
+ Too many positional arguments
225
+ bar(1, 2, 3, 4, 5, f=6);
226
+ bar(1, 2, 3, 4, 5, 6, 7, 8, 9); # too many args
227
+ ^
228
+ """,
229
+ """
230
+ Too many positional arguments
231
+ bar(1, 2, 3, 4, 5, f=6);
232
+ bar(1, 2, 3, 4, 5, 6, 7, 8, 9); # too many args
233
+ ^
234
+ """,
235
+ """
236
+ Too many positional arguments
237
+ bar(1, 2, 3, 4, 5, f=6);
238
+ bar(1, 2, 3, 4, 5, 6, 7, 8, 9); # too many args
239
+ ^
240
+ """,
241
+ """
242
+ Parameter 'c' already matched
243
+ bar(1, 2, 3, 4, 5, f=6);
244
+ bar(1, 2, 3, 4, 5, 6, 7, 8, 9); # too many args
245
+ bar(1, 2, 3, 4, 5, 6, c=3); # already matched
246
+ ^^^
247
+ """,
248
+ """
249
+ Named argument 'h' does not match any parameter
250
+ bar(1, 2, 3, 4, 5, 6, 7, 8, 9); # too many args
251
+ bar(1, 2, 3, 4, 5, 6, c=3); # already matched
252
+ bar(1, 2, 3, 4, 5, 6, h=1); # h is not matched
253
+ ^^^
254
+ """,
255
+ """
256
+ Too many positional arguments
257
+ baz(a=1, b=2);
258
+ baz(1, b=2); # a can be both positional and keyword
259
+ baz(1, 2); # 'b' can only be keyword arg
260
+ ^
261
+ """,
262
+ """
263
+ Not all required parameters were provided in the function call: 'b'
264
+ baz(a=1, b=2);
265
+ baz(1, b=2); # a can be both positional and keyword
266
+ baz(1, 2); # 'b' can only be keyword arg
267
+ ^^^^^^^^^
268
+ """,
269
+ ]
270
+
271
+ for i, expected in enumerate(expected_errors):
272
+ self._assert_error_pretty_found(expected, program.errors_had[i].pretty_print())
273
+
274
+ def test_self_type_inference(self) -> None:
275
+ path = self.fixture_abs_path("checker_self_type.jac")
276
+ program = JacProgram()
277
+ mod = program.compile(path)
278
+ TypeCheckPass(ir_in=mod, prog=program)
279
+ self.assertEqual(len(program.errors_had), 1)
280
+ self._assert_error_pretty_found("""
281
+ x: str = self.i; # <-- Error
282
+ ^^^^^^^^^^^^^^^
283
+ """, program.errors_had[0].pretty_print())
284
+
285
+ def test_binary_op(self) -> None:
286
+ program = JacProgram()
287
+ mod = program.compile(self.fixture_abs_path("checker_binary_op.jac"))
288
+ TypeCheckPass(ir_in=mod, prog=program)
289
+ self.assertEqual(len(program.errors_had), 2)
290
+ self._assert_error_pretty_found("""
291
+ r2: A = a + a; # <-- Error
292
+ ^^^^^^^^^^^^^
293
+ """, program.errors_had[0].pretty_print())
294
+ self._assert_error_pretty_found("""
295
+ r4: str = (a+a) * B(); # <-- Error
296
+ ^^^^^^^^^^^^^^^^^^^^^
297
+ """, program.errors_had[1].pretty_print())
298
+
299
+ def test_checker_call_expr_class(self) -> None:
300
+ path = self.fixture_abs_path("checker_call_expr_class.jac")
301
+ program = JacProgram()
302
+ mod = program.compile(path)
303
+ TypeCheckPass(ir_in=mod, prog=program)
304
+ self.assertEqual(len(program.errors_had), 1)
305
+ self._assert_error_pretty_found("""
306
+ inst.i = 'str'; # <-- Error
307
+ ^^^^^^^^^^^^^^
308
+ """, program.errors_had[0].pretty_print())
309
+
310
+ def test_checker_mod_path(self) -> None:
311
+ path = self.fixture_abs_path("checker_mod_path.jac")
312
+ program = JacProgram()
313
+ mod = program.compile(path)
314
+ TypeCheckPass(ir_in=mod, prog=program)
315
+ self.assertEqual(len(program.errors_had), 1)
316
+ self._assert_error_pretty_found("""
317
+ a:int = uni.Module; # <-- Error
318
+ ^^^^^^^^^^^^^^
319
+ """, program.errors_had[0].pretty_print())
320
+
321
+ def test_checker_cat_is_animal(self) -> None:
322
+ path = self.fixture_abs_path("checker_cat_is_animal.jac")
323
+ program = JacProgram()
324
+ mod = program.compile(path)
325
+ TypeCheckPass(ir_in=mod, prog=program)
326
+ self.assertEqual(len(program.errors_had), 1)
327
+ self._assert_error_pretty_found("""
328
+ animal_func(cat); # <-- Ok
329
+ animal_func(lion); # <-- Ok
330
+ animal_func(not_animal); # <-- Error
331
+ ^^^^^^^^^^
332
+ """, program.errors_had[0].pretty_print())
333
+
334
+ def test_checker_import_missing_module(self) -> None:
335
+ path = self.fixture_abs_path("checker_import_missing_module.jac")
336
+ program = JacProgram()
337
+ mod = program.compile(path)
338
+ TypeCheckPass(ir_in=mod, prog=program)
339
+ self.assertEqual(len(program.errors_had), 0)
340
+
341
+ def test_cyclic_symbol(self) -> None:
342
+ path = self.fixture_abs_path("checker_cyclic_symbol.jac")
343
+ program = JacProgram()
344
+ mod = program.compile(path)
345
+ # This will result in a stack overflow if not handled properly.
346
+ # So the fact that it has 0 errors means it passed.
347
+ TypeCheckPass(ir_in=mod, prog=program)
348
+ self.assertEqual(len(program.errors_had), 0)
84
349
 
85
350
  def _assert_error_pretty_found(self, needle: str, haystack: str) -> None:
86
351
  for line in [line.strip() for line in needle.splitlines() if line.strip()]:
@@ -0,0 +1,57 @@
1
+ """Test ast build pass module."""
2
+
3
+ import io
4
+ import os
5
+ import sys
6
+
7
+ from jaclang.compiler.program import JacProgram
8
+ from jaclang.utils.test import TestCase
9
+ from jaclang.settings import settings
10
+ from jaclang.compiler.passes.main import PreDynamoPass
11
+
12
+ class PreDynamoPassTests(TestCase):
13
+ """Test pass module."""
14
+
15
+ TargetPass = PreDynamoPass
16
+
17
+ def setUp(self) -> None:
18
+ """Set up test."""
19
+ return super().setUp()
20
+
21
+ def test_predynamo_where_assign(self) -> None:
22
+ """Test torch.where transformation."""
23
+ captured_output = io.StringIO()
24
+ sys.stdout = captured_output
25
+ os.environ["JAC_PREDYNAMO_PASS"] = "True"
26
+ settings.load_env_vars()
27
+ code_gen = JacProgram().compile(self.fixture_abs_path("predynamo_where_assign.jac"))
28
+ sys.stdout = sys.__stdout__
29
+ self.assertIn("torch.where", code_gen.unparse())
30
+ os.environ["JAC_PREDYNAMO_PASS"] = "false"
31
+ settings.load_env_vars()
32
+
33
+ def test_predynamo_where_return(self) -> None:
34
+ """Test torch.where transformation."""
35
+ captured_output = io.StringIO()
36
+ sys.stdout = captured_output
37
+ os.environ["JAC_PREDYNAMO_PASS"] = "True"
38
+ settings.load_env_vars()
39
+ code_gen = JacProgram().compile(self.fixture_abs_path("predynamo_where_return.jac"))
40
+ sys.stdout = sys.__stdout__
41
+ self.assertIn("torch.where", code_gen.unparse())
42
+ os.environ["JAC_PREDYNAMO_PASS"] = "false"
43
+ settings.load_env_vars()
44
+
45
+ def test_predynamo_fix3(self) -> None:
46
+ """Test torch.where transformation."""
47
+ captured_output = io.StringIO()
48
+ sys.stdout = captured_output
49
+ os.environ["JAC_PREDYNAMO_PASS"] = "True"
50
+ settings.load_env_vars()
51
+ code_gen = JacProgram().compile(self.fixture_abs_path("predynamo_fix3.jac"))
52
+ sys.stdout = sys.__stdout__
53
+ unparsed_code = code_gen.unparse()
54
+ self.assertIn("__inv_freq = torch.where(", unparsed_code)
55
+ self.assertIn("self.register_buffer('inv_freq', __inv_freq, persistent=False);", unparsed_code)
56
+ os.environ["JAC_PREDYNAMO_PASS"] = "false"
57
+ settings.load_env_vars()
@@ -9,80 +9,34 @@ Reference:
9
9
  craizy_type_expr branch: type_checker_pass.py
10
10
  """
11
11
 
12
- import ast as py_ast
13
- import os
14
-
15
12
  import jaclang.compiler.unitree as uni
16
13
  from jaclang.compiler.passes import UniPass
17
- from jaclang.compiler.type_system.type_evaluator import TypeEvaluator
18
- from jaclang.runtimelib.utils import read_file_with_encoding
19
-
20
- from .pyast_load_pass import PyastBuildPass
21
- from .sym_tab_build_pass import SymTabBuildPass
14
+ from jaclang.compiler.type_system import types as jtypes
22
15
 
23
16
 
24
17
  class TypeCheckPass(UniPass):
25
18
  """Type checker pass for JacLang."""
26
19
 
27
- # NOTE: This is done in the binder pass of pyright, however I'm doing this
28
- # here, cause this will be the entry point of the type checker and we're not
29
- # relying on the binder pass at the moment and we can go back to binder pass
30
- # in the future if we needed it.
31
- _BUILTINS_STUB_FILE_PATH = os.path.join(
32
- os.path.dirname(__file__),
33
- "../../../vendor/typeshed/stdlib/builtins.pyi",
34
- )
35
-
36
- # Cache the builtins module once it parsed.
37
- _BUILTINS_MODULE: uni.Module | None = None
38
-
39
20
  def before_pass(self) -> None:
40
21
  """Initialize the checker pass."""
41
- self._load_builtins_stub_module()
22
+ self.evaluator = self.prog.get_type_evaluator()
23
+ self.evaluator.diagnostic_callback = self._add_diagnostic
42
24
  self._insert_builtin_symbols()
43
25
 
44
- assert TypeCheckPass._BUILTINS_MODULE is not None
45
- self.evaluator = TypeEvaluator(
46
- builtins_module=TypeCheckPass._BUILTINS_MODULE,
47
- program=self.prog,
48
- )
26
+ def _add_diagnostic(self, node: uni.UniNode, message: str, warning: bool) -> None:
27
+ """Add a diagnostic message to the pass."""
28
+ if warning:
29
+ self.log_warning(message, node)
30
+ else:
31
+ self.log_error(message, node)
49
32
 
50
33
  # --------------------------------------------------------------------------
51
34
  # Internal helper functions
52
35
  # --------------------------------------------------------------------------
53
36
 
54
- def _binding_builtins(self) -> bool:
55
- """Return true if we're binding the builtins stub file."""
56
- return self.ir_in == TypeCheckPass._BUILTINS_MODULE
57
-
58
- def _load_builtins_stub_module(self) -> None:
59
- """Return the builtins stub module.
60
-
61
- This will parse and cache the stub file and return the cached module on
62
- subsequent calls.
63
- """
64
- if self._binding_builtins() or TypeCheckPass._BUILTINS_MODULE is not None:
65
- return
66
-
67
- if not os.path.exists(TypeCheckPass._BUILTINS_STUB_FILE_PATH):
68
- raise FileNotFoundError(
69
- f"Builtins stub file not found at {TypeCheckPass._BUILTINS_STUB_FILE_PATH}"
70
- )
71
-
72
- file_content = read_file_with_encoding(TypeCheckPass._BUILTINS_STUB_FILE_PATH)
73
- uni_source = uni.Source(file_content, TypeCheckPass._BUILTINS_STUB_FILE_PATH)
74
- mod = PyastBuildPass(
75
- ir_in=uni.PythonModuleAst(
76
- py_ast.parse(file_content),
77
- orig_src=uni_source,
78
- ),
79
- prog=self.prog,
80
- ).ir_out
81
- SymTabBuildPass(ir_in=mod, prog=self.prog)
82
- TypeCheckPass._BUILTINS_MODULE = mod
83
-
84
37
  def _insert_builtin_symbols(self) -> None:
85
- if self._binding_builtins():
38
+ # Don't insert builtin symbols into the builtin module itself.
39
+ if self.ir_in == self.evaluator.builtins_module:
86
40
  return
87
41
 
88
42
  # TODO: Insert these symbols.
@@ -92,20 +46,34 @@ class TypeCheckPass(UniPass):
92
46
  # '__name__', '__loader__', '__package__', '__spec__', '__path__',
93
47
  # '__file__', '__cached__', '__dict__', '__annotations__',
94
48
  # '__builtins__', '__doc__',
95
- assert (
96
- TypeCheckPass._BUILTINS_MODULE is not None
97
- ), "Builtins module is not loaded"
98
49
  if self.ir_in.parent_scope is not None:
99
50
  self.log_info("Builtins module is already bound, skipping.")
100
51
  return
101
52
  # Review: If we ever assume a module cannot have a parent scope, this will
102
53
  # break that contract.
103
- self.ir_in.parent_scope = TypeCheckPass._BUILTINS_MODULE
54
+ self.ir_in.parent_scope = self.evaluator.builtins_module
104
55
 
105
56
  # --------------------------------------------------------------------------
106
57
  # Ast walker hooks
107
58
  # --------------------------------------------------------------------------
108
59
 
60
+ def enter_ability(self, node: uni.Ability) -> None:
61
+ """Enter an ability node."""
62
+ # If the node has @staticmethod decorator, mark it as static method.
63
+ # this is needed since ast raised from python does not have this info.
64
+ for decor in node.decorators or []:
65
+ ty = self.evaluator.get_type_of_expression(decor)
66
+ if isinstance(ty, jtypes.ClassType) and ty.is_builtin("staticmethod"):
67
+ node.is_static = True
68
+ break
69
+
70
+ def exit_import(self, node: uni.Import) -> None:
71
+ """Exit an import node."""
72
+ if node.from_loc:
73
+ for item in node.items:
74
+ if isinstance(item, uni.ModuleItem):
75
+ self.evaluator.get_type_of_module_item(item)
76
+
109
77
  def exit_assignment(self, node: uni.Assignment) -> None:
110
78
  """Pyright: Checker.visitAssignment(node: AssignmentNode): boolean."""
111
79
  # TODO: In pyright this logic is present at evaluateTypesForAssignmentStatement
@@ -126,3 +94,10 @@ class TypeCheckPass(UniPass):
126
94
  def exit_atom_trailer(self, node: uni.AtomTrailer) -> None:
127
95
  """Handle the atom trailer node."""
128
96
  self.evaluator.get_type_of_expression(node)
97
+
98
+ def exit_func_call(self, node: uni.FuncCall) -> None:
99
+ """Handle the function call node."""
100
+ # TODO:
101
+ # 1. Function Existence & Callable Validation
102
+ # 2. Argument Matching(count, types, names)
103
+ self.evaluator.get_type_of_expression(node)