jaclang 0.8.6__py3-none-any.whl → 0.8.8__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of jaclang might be problematic. Click here for more details.
- jaclang/cli/cli.md +3 -3
- jaclang/cli/cli.py +37 -37
- jaclang/cli/cmdreg.py +45 -140
- jaclang/compiler/constant.py +0 -1
- jaclang/compiler/jac.lark +3 -6
- jaclang/compiler/larkparse/jac_parser.py +2 -2
- jaclang/compiler/parser.py +213 -34
- jaclang/compiler/passes/main/__init__.py +2 -4
- jaclang/compiler/passes/main/def_use_pass.py +0 -4
- jaclang/compiler/passes/main/predynamo_pass.py +221 -0
- jaclang/compiler/passes/main/pyast_gen_pass.py +83 -55
- jaclang/compiler/passes/main/pyast_load_pass.py +66 -40
- jaclang/compiler/passes/main/sym_tab_build_pass.py +1 -1
- jaclang/compiler/passes/main/tests/fixtures/checker/import_sym.jac +2 -0
- jaclang/compiler/passes/main/tests/fixtures/checker/import_sym_test.jac +6 -0
- jaclang/compiler/passes/main/tests/fixtures/checker/imported_sym.jac +5 -0
- jaclang/compiler/passes/main/tests/fixtures/checker_arg_param_match.jac +37 -0
- jaclang/compiler/passes/main/tests/fixtures/checker_arity.jac +18 -0
- jaclang/compiler/passes/main/tests/fixtures/checker_binary_op.jac +21 -0
- jaclang/compiler/passes/main/tests/fixtures/checker_call_expr_class.jac +12 -0
- jaclang/compiler/passes/main/tests/fixtures/checker_cat_is_animal.jac +18 -0
- jaclang/compiler/passes/main/tests/fixtures/checker_cyclic_symbol.jac +4 -0
- jaclang/compiler/passes/main/tests/fixtures/checker_expr_call.jac +9 -0
- jaclang/compiler/passes/main/tests/fixtures/checker_float.jac +7 -0
- jaclang/compiler/passes/main/tests/fixtures/checker_import_missing_module.jac +13 -0
- jaclang/compiler/passes/main/tests/fixtures/checker_magic_call.jac +17 -0
- jaclang/compiler/passes/main/tests/fixtures/checker_mod_path.jac +8 -0
- jaclang/compiler/passes/main/tests/fixtures/checker_param_types.jac +11 -0
- jaclang/compiler/passes/main/tests/fixtures/checker_self_type.jac +9 -0
- jaclang/compiler/passes/main/tests/fixtures/checker_sym_inherit.jac +42 -0
- jaclang/compiler/passes/main/tests/fixtures/predynamo_fix3.jac +43 -0
- jaclang/compiler/passes/main/tests/fixtures/predynamo_where_assign.jac +13 -0
- jaclang/compiler/passes/main/tests/fixtures/predynamo_where_return.jac +11 -0
- jaclang/compiler/passes/main/tests/test_checker_pass.py +265 -0
- jaclang/compiler/passes/main/tests/test_predynamo_pass.py +57 -0
- jaclang/compiler/passes/main/type_checker_pass.py +36 -61
- jaclang/compiler/passes/tool/doc_ir_gen_pass.py +204 -44
- jaclang/compiler/passes/tool/jac_formatter_pass.py +119 -69
- jaclang/compiler/passes/tool/tests/fixtures/corelib_fmt.jac +3 -3
- jaclang/compiler/passes/tool/tests/fixtures/general_format_checks/triple_quoted_string.jac +4 -5
- jaclang/compiler/passes/tool/tests/fixtures/tagbreak.jac +171 -11
- jaclang/compiler/passes/transform.py +12 -8
- jaclang/compiler/program.py +14 -6
- jaclang/compiler/tests/fixtures/jac_import_py_files.py +4 -0
- jaclang/compiler/tests/fixtures/jac_module.jac +3 -0
- jaclang/compiler/tests/fixtures/multiple_syntax_errors.jac +10 -0
- jaclang/compiler/tests/fixtures/python_module.py +1 -0
- jaclang/compiler/tests/test_importer.py +39 -0
- jaclang/compiler/tests/test_parser.py +49 -0
- jaclang/compiler/type_system/operations.py +104 -0
- jaclang/compiler/type_system/type_evaluator.py +470 -47
- jaclang/compiler/type_system/type_utils.py +246 -0
- jaclang/compiler/type_system/types.py +58 -2
- jaclang/compiler/unitree.py +79 -94
- jaclang/langserve/engine.jac +253 -230
- jaclang/langserve/server.jac +46 -15
- jaclang/langserve/tests/fixtures/circle.jac +3 -3
- jaclang/langserve/tests/fixtures/circle_err.jac +3 -3
- jaclang/langserve/tests/fixtures/circle_pure.test.jac +3 -3
- jaclang/langserve/tests/fixtures/completion_test_err.jac +10 -0
- jaclang/langserve/tests/server_test/circle_template.jac +80 -0
- jaclang/langserve/tests/server_test/glob_template.jac +4 -0
- jaclang/langserve/tests/server_test/test_lang_serve.py +154 -312
- jaclang/langserve/tests/server_test/utils.py +153 -116
- jaclang/langserve/tests/test_dev_server.py +1 -1
- jaclang/langserve/tests/test_server.py +30 -86
- jaclang/langserve/utils.jac +56 -63
- jaclang/runtimelib/machine.py +7 -0
- jaclang/runtimelib/meta_importer.py +27 -1
- jaclang/runtimelib/tests/fixtures/custom_access_validation.jac +1 -1
- jaclang/runtimelib/tests/fixtures/savable_object.jac +2 -2
- jaclang/settings.py +18 -14
- jaclang/tests/fixtures/abc_check.jac +3 -3
- jaclang/tests/fixtures/arch_rel_import_creation.jac +12 -12
- jaclang/tests/fixtures/chandra_bugs2.jac +3 -3
- jaclang/tests/fixtures/create_dynamic_archetype.jac +13 -13
- jaclang/tests/fixtures/jac_run_py_bugs.py +18 -0
- jaclang/tests/fixtures/jac_run_py_import.py +13 -0
- jaclang/tests/fixtures/lambda_arg_annotation.jac +15 -0
- jaclang/tests/fixtures/lambda_self.jac +18 -0
- jaclang/tests/fixtures/maxfail_run_test.jac +4 -4
- jaclang/tests/fixtures/params/param_syntax_err.jac +9 -0
- jaclang/tests/fixtures/params/test_complex_params.jac +42 -0
- jaclang/tests/fixtures/params/test_failing_kwonly.jac +207 -0
- jaclang/tests/fixtures/params/test_failing_posonly.jac +116 -0
- jaclang/tests/fixtures/params/test_failing_varargs.jac +300 -0
- jaclang/tests/fixtures/params/test_kwonly_params.jac +29 -0
- jaclang/tests/fixtures/py2jac_params.py +8 -0
- jaclang/tests/fixtures/run_test.jac +4 -4
- jaclang/tests/test_cli.py +103 -18
- jaclang/tests/test_language.py +74 -16
- jaclang/utils/helpers.py +47 -2
- jaclang/utils/module_resolver.py +11 -1
- jaclang/utils/test.py +8 -0
- jaclang/utils/treeprinter.py +0 -18
- {jaclang-0.8.6.dist-info → jaclang-0.8.8.dist-info}/METADATA +3 -3
- {jaclang-0.8.6.dist-info → jaclang-0.8.8.dist-info}/RECORD +99 -62
- {jaclang-0.8.6.dist-info → jaclang-0.8.8.dist-info}/WHEEL +1 -1
- jaclang/compiler/passes/main/inheritance_pass.py +0 -131
- jaclang/langserve/dev_engine.jac +0 -645
- jaclang/langserve/dev_server.jac +0 -201
- jaclang/langserve/tests/server_test/code_test.py +0 -0
- {jaclang-0.8.6.dist-info → jaclang-0.8.8.dist-info}/entry_points.txt +0 -0
|
@@ -0,0 +1,42 @@
|
|
|
1
|
+
|
|
2
|
+
# -----------------------------------------------------------------------------
|
|
3
|
+
# Simple Inheritance
|
|
4
|
+
# -----------------------------------------------------------------------------
|
|
5
|
+
|
|
6
|
+
node Parent {
|
|
7
|
+
has val: int = 42;
|
|
8
|
+
}
|
|
9
|
+
|
|
10
|
+
node Child (Parent) {
|
|
11
|
+
|
|
12
|
+
}
|
|
13
|
+
|
|
14
|
+
|
|
15
|
+
with entry {
|
|
16
|
+
c = Child();
|
|
17
|
+
c.val = 42; # <-- Ok
|
|
18
|
+
c.val = "str"; # <-- Error
|
|
19
|
+
}
|
|
20
|
+
|
|
21
|
+
|
|
22
|
+
# -----------------------------------------------------------------------------
|
|
23
|
+
# A Complex Inheritance
|
|
24
|
+
# -----------------------------------------------------------------------------
|
|
25
|
+
|
|
26
|
+
node Animal {
|
|
27
|
+
has name: str;
|
|
28
|
+
}
|
|
29
|
+
|
|
30
|
+
node Cat(Animal) {
|
|
31
|
+
}
|
|
32
|
+
|
|
33
|
+
node Lion(Cat) {
|
|
34
|
+
}
|
|
35
|
+
|
|
36
|
+
|
|
37
|
+
with entry {
|
|
38
|
+
l = Lion();
|
|
39
|
+
|
|
40
|
+
l.name = "Simba"; # <-- Ok
|
|
41
|
+
l.name = 42; # <-- Error
|
|
42
|
+
}
|
|
@@ -0,0 +1,43 @@
|
|
|
1
|
+
import torch;
|
|
2
|
+
class Cfg {
|
|
3
|
+
def __init__(self: Cfg, max_position_embeddings: Any = 8, has_original: Any = False) {
|
|
4
|
+
self.max_position_embeddings = max_position_embeddings;
|
|
5
|
+
if has_original {
|
|
6
|
+
self.original_max_position_embeddings = (max_position_embeddings // 2);
|
|
7
|
+
}
|
|
8
|
+
}
|
|
9
|
+
}
|
|
10
|
+
class ToyModel(torch.nn.Module) {
|
|
11
|
+
def __init__(self: ToyModel, cfg: Cfg) {
|
|
12
|
+
super.init();
|
|
13
|
+
self.config = cfg;
|
|
14
|
+
self.long_inv_freq = torch.tensor([10.0, 20.0, 30.0]);
|
|
15
|
+
self.original_inv_freq = torch.tensor([1.0, 2.0, 3.0]);
|
|
16
|
+
}
|
|
17
|
+
def rope_init_fn(self: ToyModel, cfg: Any, device: Any, seq_len: int) {
|
|
18
|
+
return (torch.arange(1, 4, dtype=torch.float32, device=device), None);
|
|
19
|
+
}
|
|
20
|
+
def _longrope_frequency_update(
|
|
21
|
+
self: ToyModel,
|
|
22
|
+
position_ids: Any,
|
|
23
|
+
device: Any = 'cpu'
|
|
24
|
+
) {
|
|
25
|
+
seq_len = (torch.max(position_ids) + 1);
|
|
26
|
+
if hasattr(self.config, 'original_max_position_embeddings') {
|
|
27
|
+
original_max_position_embeddings = self.config.original_max_position_embeddings;
|
|
28
|
+
} else {
|
|
29
|
+
original_max_position_embeddings = self.config.max_position_embeddings;
|
|
30
|
+
}
|
|
31
|
+
if (seq_len > original_max_position_embeddings) {
|
|
32
|
+
self.register_buffer('inv_freq', self.long_inv_freq, persistent=False);
|
|
33
|
+
} else {
|
|
34
|
+
self.register_buffer(
|
|
35
|
+
'inv_freq',
|
|
36
|
+
self.original_inv_freq.to(device),
|
|
37
|
+
persistent=False
|
|
38
|
+
);
|
|
39
|
+
}
|
|
40
|
+
return self.inv_freq;
|
|
41
|
+
}
|
|
42
|
+
}
|
|
43
|
+
|
|
@@ -28,6 +28,18 @@ class TypeCheckerPassTests(TestCase):
|
|
|
28
28
|
^^^^^^^^^^^^^^^^^^^^^^
|
|
29
29
|
""", program.errors_had[1].pretty_print())
|
|
30
30
|
|
|
31
|
+
def test_float_types(self) -> None:
|
|
32
|
+
program = JacProgram()
|
|
33
|
+
mod = program.compile(self.fixture_abs_path("checker_float.jac"))
|
|
34
|
+
TypeCheckPass(ir_in=mod, prog=program)
|
|
35
|
+
self.assertEqual(len(program.errors_had), 1)
|
|
36
|
+
self._assert_error_pretty_found("""
|
|
37
|
+
f: float = pi; # <-- OK
|
|
38
|
+
s: str = pi; # <-- Error
|
|
39
|
+
^^^^^^^^^^^
|
|
40
|
+
""", program.errors_had[0].pretty_print())
|
|
41
|
+
|
|
42
|
+
|
|
31
43
|
def test_infer_type_of_assignment(self) -> None:
|
|
32
44
|
program = JacProgram()
|
|
33
45
|
mod = program.compile(self.fixture_abs_path("infer_type_assignment.jac"))
|
|
@@ -49,6 +61,17 @@ class TypeCheckerPassTests(TestCase):
|
|
|
49
61
|
^^^^^^^^^^^^^^^^^^^
|
|
50
62
|
""", program.errors_had[0].pretty_print())
|
|
51
63
|
|
|
64
|
+
def test_imported_sym(self) -> None:
|
|
65
|
+
program = JacProgram()
|
|
66
|
+
mod = program.compile(self.fixture_abs_path("checker/import_sym_test.jac"))
|
|
67
|
+
TypeCheckPass(ir_in=mod, prog=program)
|
|
68
|
+
self.assertEqual(len(program.errors_had), 1)
|
|
69
|
+
self._assert_error_pretty_found("""
|
|
70
|
+
a: str = foo(); # <-- Ok
|
|
71
|
+
b: int = foo(); # <-- Error
|
|
72
|
+
^^^^^^^^^^^^^^
|
|
73
|
+
""", program.errors_had[0].pretty_print())
|
|
74
|
+
|
|
52
75
|
def test_member_access_type_infered(self) -> None:
|
|
53
76
|
program = JacProgram()
|
|
54
77
|
mod = program.compile(self.fixture_abs_path("member_access_type_inferred.jac"))
|
|
@@ -59,6 +82,22 @@ class TypeCheckerPassTests(TestCase):
|
|
|
59
82
|
^^^^^^^^^
|
|
60
83
|
""", program.errors_had[0].pretty_print())
|
|
61
84
|
|
|
85
|
+
def test_inherited_symbol(self) -> None:
|
|
86
|
+
program = JacProgram()
|
|
87
|
+
mod = program.compile(self.fixture_abs_path("checker_sym_inherit.jac"))
|
|
88
|
+
TypeCheckPass(ir_in=mod, prog=program)
|
|
89
|
+
self.assertEqual(len(program.errors_had), 2)
|
|
90
|
+
self._assert_error_pretty_found("""
|
|
91
|
+
c.val = 42; # <-- Ok
|
|
92
|
+
c.val = "str"; # <-- Error
|
|
93
|
+
^^^^^^^^^^^^^
|
|
94
|
+
""", program.errors_had[0].pretty_print())
|
|
95
|
+
self._assert_error_pretty_found("""
|
|
96
|
+
l.name = "Simba"; # <-- Ok
|
|
97
|
+
l.name = 42; # <-- Error
|
|
98
|
+
^^^^^^^^^^^
|
|
99
|
+
""", program.errors_had[1].pretty_print())
|
|
100
|
+
|
|
62
101
|
def test_import_symbol_type_infer(self) -> None:
|
|
63
102
|
program = JacProgram()
|
|
64
103
|
mod = program.compile(self.fixture_abs_path("import_symbol_type_infer.jac"))
|
|
@@ -81,6 +120,232 @@ class TypeCheckerPassTests(TestCase):
|
|
|
81
120
|
^^^^^^^^^^^^^^
|
|
82
121
|
""", program.errors_had[0].pretty_print())
|
|
83
122
|
|
|
123
|
+
def test_call_expr(self) -> None:
|
|
124
|
+
path = self.fixture_abs_path("checker_expr_call.jac")
|
|
125
|
+
program = JacProgram()
|
|
126
|
+
mod = program.compile(path)
|
|
127
|
+
TypeCheckPass(ir_in=mod, prog=program)
|
|
128
|
+
self.assertEqual(len(program.errors_had), 1)
|
|
129
|
+
self._assert_error_pretty_found("""
|
|
130
|
+
s: str = foo();
|
|
131
|
+
^^^^^^^^^^^^^^
|
|
132
|
+
""", program.errors_had[0].pretty_print())
|
|
133
|
+
|
|
134
|
+
def test_call_expr_magic(self) -> None:
|
|
135
|
+
path = self.fixture_abs_path("checker_magic_call.jac")
|
|
136
|
+
program = JacProgram()
|
|
137
|
+
mod = program.compile(path)
|
|
138
|
+
TypeCheckPass(ir_in=mod, prog=program)
|
|
139
|
+
self.assertEqual(len(program.errors_had), 1)
|
|
140
|
+
self._assert_error_pretty_found("""
|
|
141
|
+
b: Bar = fn()(); # <-- Ok
|
|
142
|
+
f: Foo = fn()(); # <-- Error
|
|
143
|
+
^^^^^^^^^^^^^^^^
|
|
144
|
+
""", program.errors_had[0].pretty_print())
|
|
145
|
+
|
|
146
|
+
def test_arity(self) -> None:
|
|
147
|
+
path = self.fixture_abs_path("checker_arity.jac")
|
|
148
|
+
program = JacProgram()
|
|
149
|
+
mod = program.compile(path)
|
|
150
|
+
TypeCheckPass(ir_in=mod, prog=program)
|
|
151
|
+
self.assertEqual(len(program.errors_had), 3)
|
|
152
|
+
self._assert_error_pretty_found("""
|
|
153
|
+
f.first_is_self(f); # <-- Error
|
|
154
|
+
^
|
|
155
|
+
""", program.errors_had[0].pretty_print())
|
|
156
|
+
self._assert_error_pretty_found("""
|
|
157
|
+
f.with_default_args(1, 2, 3); # <-- Error
|
|
158
|
+
^
|
|
159
|
+
""", program.errors_had[1].pretty_print())
|
|
160
|
+
self._assert_error_pretty_found("""
|
|
161
|
+
f.with_default_args(); # <-- Error
|
|
162
|
+
^^^^^^^^^^^^^^^^^^^^^
|
|
163
|
+
""", program.errors_had[2].pretty_print())
|
|
164
|
+
|
|
165
|
+
def test_param_types(self) -> None:
|
|
166
|
+
path = self.fixture_abs_path("checker_param_types.jac")
|
|
167
|
+
program = JacProgram()
|
|
168
|
+
mod = program.compile(path)
|
|
169
|
+
TypeCheckPass(ir_in=mod, prog=program)
|
|
170
|
+
self.assertEqual(len(program.errors_had), 1)
|
|
171
|
+
self._assert_error_pretty_found("""
|
|
172
|
+
foo(A()); # <-- Ok
|
|
173
|
+
foo(B()); # <-- Error
|
|
174
|
+
^^^
|
|
175
|
+
""", program.errors_had[0].pretty_print())
|
|
176
|
+
|
|
177
|
+
def test_param_arg_match(self) -> None:
|
|
178
|
+
path = self.fixture_abs_path("checker_param_types.jac")
|
|
179
|
+
program = JacProgram()
|
|
180
|
+
path = self.fixture_abs_path("checker_arg_param_match.jac")
|
|
181
|
+
mod = program.compile(path)
|
|
182
|
+
TypeCheckPass(ir_in=mod, prog=program)
|
|
183
|
+
self.assertEqual(len(program.errors_had), 13)
|
|
184
|
+
|
|
185
|
+
expected_errors = [
|
|
186
|
+
"""
|
|
187
|
+
Not all required parameters were provided in the function call: 'a'
|
|
188
|
+
f = Foo();
|
|
189
|
+
f.bar();
|
|
190
|
+
^^^^^^^
|
|
191
|
+
""",
|
|
192
|
+
"""
|
|
193
|
+
Too many positional arguments
|
|
194
|
+
f.bar();
|
|
195
|
+
f.bar(1);
|
|
196
|
+
f.bar(1, 2);
|
|
197
|
+
^
|
|
198
|
+
""",
|
|
199
|
+
"""
|
|
200
|
+
Not all required parameters were provided in the function call: 'self', 'a'
|
|
201
|
+
f.bar(1, 2);
|
|
202
|
+
f.baz();
|
|
203
|
+
^^^^^^^
|
|
204
|
+
""",
|
|
205
|
+
"""
|
|
206
|
+
Not all required parameters were provided in the function call: 'a'
|
|
207
|
+
f.baz();
|
|
208
|
+
f.baz(1);
|
|
209
|
+
^^^^^^^^
|
|
210
|
+
""",
|
|
211
|
+
"""
|
|
212
|
+
Not all required parameters were provided in the function call: 'f'
|
|
213
|
+
foo(1, 2, d=3, e=4, f=5, c=4); # order does not matter for named
|
|
214
|
+
foo(1, 2, 3, d=4, e=5, g=7, h=8); # missing argument 'f'
|
|
215
|
+
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
|
|
216
|
+
""",
|
|
217
|
+
"""
|
|
218
|
+
Positional only parameter 'b' cannot be matched with a named argument
|
|
219
|
+
foo(1, 2, 3, d=4, e=5, g=7, h=8); # missing argument 'f'
|
|
220
|
+
foo(1, b=2, c=3, d=4, e=5, f=6); # b is positional only
|
|
221
|
+
^^^
|
|
222
|
+
""",
|
|
223
|
+
"""
|
|
224
|
+
Too many positional arguments
|
|
225
|
+
bar(1, 2, 3, 4, 5, f=6);
|
|
226
|
+
bar(1, 2, 3, 4, 5, 6, 7, 8, 9); # too many args
|
|
227
|
+
^
|
|
228
|
+
""",
|
|
229
|
+
"""
|
|
230
|
+
Too many positional arguments
|
|
231
|
+
bar(1, 2, 3, 4, 5, f=6);
|
|
232
|
+
bar(1, 2, 3, 4, 5, 6, 7, 8, 9); # too many args
|
|
233
|
+
^
|
|
234
|
+
""",
|
|
235
|
+
"""
|
|
236
|
+
Too many positional arguments
|
|
237
|
+
bar(1, 2, 3, 4, 5, f=6);
|
|
238
|
+
bar(1, 2, 3, 4, 5, 6, 7, 8, 9); # too many args
|
|
239
|
+
^
|
|
240
|
+
""",
|
|
241
|
+
"""
|
|
242
|
+
Parameter 'c' already matched
|
|
243
|
+
bar(1, 2, 3, 4, 5, f=6);
|
|
244
|
+
bar(1, 2, 3, 4, 5, 6, 7, 8, 9); # too many args
|
|
245
|
+
bar(1, 2, 3, 4, 5, 6, c=3); # already matched
|
|
246
|
+
^^^
|
|
247
|
+
""",
|
|
248
|
+
"""
|
|
249
|
+
Named argument 'h' does not match any parameter
|
|
250
|
+
bar(1, 2, 3, 4, 5, 6, 7, 8, 9); # too many args
|
|
251
|
+
bar(1, 2, 3, 4, 5, 6, c=3); # already matched
|
|
252
|
+
bar(1, 2, 3, 4, 5, 6, h=1); # h is not matched
|
|
253
|
+
^^^
|
|
254
|
+
""",
|
|
255
|
+
"""
|
|
256
|
+
Too many positional arguments
|
|
257
|
+
baz(a=1, b=2);
|
|
258
|
+
baz(1, b=2); # a can be both positional and keyword
|
|
259
|
+
baz(1, 2); # 'b' can only be keyword arg
|
|
260
|
+
^
|
|
261
|
+
""",
|
|
262
|
+
"""
|
|
263
|
+
Not all required parameters were provided in the function call: 'b'
|
|
264
|
+
baz(a=1, b=2);
|
|
265
|
+
baz(1, b=2); # a can be both positional and keyword
|
|
266
|
+
baz(1, 2); # 'b' can only be keyword arg
|
|
267
|
+
^^^^^^^^^
|
|
268
|
+
""",
|
|
269
|
+
]
|
|
270
|
+
|
|
271
|
+
for i, expected in enumerate(expected_errors):
|
|
272
|
+
self._assert_error_pretty_found(expected, program.errors_had[i].pretty_print())
|
|
273
|
+
|
|
274
|
+
def test_self_type_inference(self) -> None:
|
|
275
|
+
path = self.fixture_abs_path("checker_self_type.jac")
|
|
276
|
+
program = JacProgram()
|
|
277
|
+
mod = program.compile(path)
|
|
278
|
+
TypeCheckPass(ir_in=mod, prog=program)
|
|
279
|
+
self.assertEqual(len(program.errors_had), 1)
|
|
280
|
+
self._assert_error_pretty_found("""
|
|
281
|
+
x: str = self.i; # <-- Error
|
|
282
|
+
^^^^^^^^^^^^^^^
|
|
283
|
+
""", program.errors_had[0].pretty_print())
|
|
284
|
+
|
|
285
|
+
def test_binary_op(self) -> None:
|
|
286
|
+
program = JacProgram()
|
|
287
|
+
mod = program.compile(self.fixture_abs_path("checker_binary_op.jac"))
|
|
288
|
+
TypeCheckPass(ir_in=mod, prog=program)
|
|
289
|
+
self.assertEqual(len(program.errors_had), 2)
|
|
290
|
+
self._assert_error_pretty_found("""
|
|
291
|
+
r2: A = a + a; # <-- Error
|
|
292
|
+
^^^^^^^^^^^^^
|
|
293
|
+
""", program.errors_had[0].pretty_print())
|
|
294
|
+
self._assert_error_pretty_found("""
|
|
295
|
+
r4: str = (a+a) * B(); # <-- Error
|
|
296
|
+
^^^^^^^^^^^^^^^^^^^^^
|
|
297
|
+
""", program.errors_had[1].pretty_print())
|
|
298
|
+
|
|
299
|
+
def test_checker_call_expr_class(self) -> None:
|
|
300
|
+
path = self.fixture_abs_path("checker_call_expr_class.jac")
|
|
301
|
+
program = JacProgram()
|
|
302
|
+
mod = program.compile(path)
|
|
303
|
+
TypeCheckPass(ir_in=mod, prog=program)
|
|
304
|
+
self.assertEqual(len(program.errors_had), 1)
|
|
305
|
+
self._assert_error_pretty_found("""
|
|
306
|
+
inst.i = 'str'; # <-- Error
|
|
307
|
+
^^^^^^^^^^^^^^
|
|
308
|
+
""", program.errors_had[0].pretty_print())
|
|
309
|
+
|
|
310
|
+
def test_checker_mod_path(self) -> None:
|
|
311
|
+
path = self.fixture_abs_path("checker_mod_path.jac")
|
|
312
|
+
program = JacProgram()
|
|
313
|
+
mod = program.compile(path)
|
|
314
|
+
TypeCheckPass(ir_in=mod, prog=program)
|
|
315
|
+
self.assertEqual(len(program.errors_had), 1)
|
|
316
|
+
self._assert_error_pretty_found("""
|
|
317
|
+
a:int = uni.Module; # <-- Error
|
|
318
|
+
^^^^^^^^^^^^^^
|
|
319
|
+
""", program.errors_had[0].pretty_print())
|
|
320
|
+
|
|
321
|
+
def test_checker_cat_is_animal(self) -> None:
|
|
322
|
+
path = self.fixture_abs_path("checker_cat_is_animal.jac")
|
|
323
|
+
program = JacProgram()
|
|
324
|
+
mod = program.compile(path)
|
|
325
|
+
TypeCheckPass(ir_in=mod, prog=program)
|
|
326
|
+
self.assertEqual(len(program.errors_had), 1)
|
|
327
|
+
self._assert_error_pretty_found("""
|
|
328
|
+
animal_func(cat); # <-- Ok
|
|
329
|
+
animal_func(lion); # <-- Ok
|
|
330
|
+
animal_func(not_animal); # <-- Error
|
|
331
|
+
^^^^^^^^^^
|
|
332
|
+
""", program.errors_had[0].pretty_print())
|
|
333
|
+
|
|
334
|
+
def test_checker_import_missing_module(self) -> None:
|
|
335
|
+
path = self.fixture_abs_path("checker_import_missing_module.jac")
|
|
336
|
+
program = JacProgram()
|
|
337
|
+
mod = program.compile(path)
|
|
338
|
+
TypeCheckPass(ir_in=mod, prog=program)
|
|
339
|
+
self.assertEqual(len(program.errors_had), 0)
|
|
340
|
+
|
|
341
|
+
def test_cyclic_symbol(self) -> None:
|
|
342
|
+
path = self.fixture_abs_path("checker_cyclic_symbol.jac")
|
|
343
|
+
program = JacProgram()
|
|
344
|
+
mod = program.compile(path)
|
|
345
|
+
# This will result in a stack overflow if not handled properly.
|
|
346
|
+
# So the fact that it has 0 errors means it passed.
|
|
347
|
+
TypeCheckPass(ir_in=mod, prog=program)
|
|
348
|
+
self.assertEqual(len(program.errors_had), 0)
|
|
84
349
|
|
|
85
350
|
def _assert_error_pretty_found(self, needle: str, haystack: str) -> None:
|
|
86
351
|
for line in [line.strip() for line in needle.splitlines() if line.strip()]:
|
|
@@ -0,0 +1,57 @@
|
|
|
1
|
+
"""Test PreDynamo pass module."""
|
|
2
|
+
|
|
3
|
+
import io
|
|
4
|
+
import os
|
|
5
|
+
import sys
|
|
6
|
+
|
|
7
|
+
from jaclang.compiler.program import JacProgram
|
|
8
|
+
from jaclang.utils.test import TestCase
|
|
9
|
+
from jaclang.settings import settings
|
|
10
|
+
from jaclang.compiler.passes.main import PreDynamoPass
|
|
11
|
+
|
|
12
|
+
class PreDynamoPassTests(TestCase):
|
|
13
|
+
"""Test pass module."""
|
|
14
|
+
|
|
15
|
+
TargetPass = PreDynamoPass
|
|
16
|
+
|
|
17
|
+
def setUp(self) -> None:
|
|
18
|
+
"""Set up test."""
|
|
19
|
+
return super().setUp()
|
|
20
|
+
|
|
21
|
+
def test_predynamo_where_assign(self) -> None:
|
|
22
|
+
"""Test torch.where transformation."""
|
|
23
|
+
captured_output = io.StringIO()
|
|
24
|
+
sys.stdout = captured_output
|
|
25
|
+
os.environ["JAC_PREDYNAMO_PASS"] = "True"
|
|
26
|
+
settings.load_env_vars()
|
|
27
|
+
code_gen = JacProgram().compile(self.fixture_abs_path("predynamo_where_assign.jac"))
|
|
28
|
+
sys.stdout = sys.__stdout__
|
|
29
|
+
self.assertIn("torch.where", code_gen.unparse())
|
|
30
|
+
os.environ["JAC_PREDYNAMO_PASS"] = "false"
|
|
31
|
+
settings.load_env_vars()
|
|
32
|
+
|
|
33
|
+
def test_predynamo_where_return(self) -> None:
|
|
34
|
+
"""Test torch.where transformation."""
|
|
35
|
+
captured_output = io.StringIO()
|
|
36
|
+
sys.stdout = captured_output
|
|
37
|
+
os.environ["JAC_PREDYNAMO_PASS"] = "True"
|
|
38
|
+
settings.load_env_vars()
|
|
39
|
+
code_gen = JacProgram().compile(self.fixture_abs_path("predynamo_where_return.jac"))
|
|
40
|
+
sys.stdout = sys.__stdout__
|
|
41
|
+
self.assertIn("torch.where", code_gen.unparse())
|
|
42
|
+
os.environ["JAC_PREDYNAMO_PASS"] = "false"
|
|
43
|
+
settings.load_env_vars()
|
|
44
|
+
|
|
45
|
+
def test_predynamo_fix3(self) -> None:
|
|
46
|
+
"""Test torch.where transformation."""
|
|
47
|
+
captured_output = io.StringIO()
|
|
48
|
+
sys.stdout = captured_output
|
|
49
|
+
os.environ["JAC_PREDYNAMO_PASS"] = "True"
|
|
50
|
+
settings.load_env_vars()
|
|
51
|
+
code_gen = JacProgram().compile(self.fixture_abs_path("predynamo_fix3.jac"))
|
|
52
|
+
sys.stdout = sys.__stdout__
|
|
53
|
+
unparsed_code = code_gen.unparse()
|
|
54
|
+
self.assertIn("__inv_freq = torch.where(", unparsed_code)
|
|
55
|
+
self.assertIn("self.register_buffer('inv_freq', __inv_freq, persistent=False);", unparsed_code)
|
|
56
|
+
os.environ["JAC_PREDYNAMO_PASS"] = "false"
|
|
57
|
+
settings.load_env_vars()
|
|
@@ -9,80 +9,34 @@ Reference:
|
|
|
9
9
|
craizy_type_expr branch: type_checker_pass.py
|
|
10
10
|
"""
|
|
11
11
|
|
|
12
|
-
import ast as py_ast
|
|
13
|
-
import os
|
|
14
|
-
|
|
15
12
|
import jaclang.compiler.unitree as uni
|
|
16
13
|
from jaclang.compiler.passes import UniPass
|
|
17
|
-
from jaclang.compiler.type_system
|
|
18
|
-
from jaclang.runtimelib.utils import read_file_with_encoding
|
|
19
|
-
|
|
20
|
-
from .pyast_load_pass import PyastBuildPass
|
|
21
|
-
from .sym_tab_build_pass import SymTabBuildPass
|
|
14
|
+
from jaclang.compiler.type_system import types as jtypes
|
|
22
15
|
|
|
23
16
|
|
|
24
17
|
class TypeCheckPass(UniPass):
|
|
25
18
|
"""Type checker pass for JacLang."""
|
|
26
19
|
|
|
27
|
-
# NOTE: This is done in the binder pass of pyright, however I'm doing this
|
|
28
|
-
# here, cause this will be the entry point of the type checker and we're not
|
|
29
|
-
# relying on the binder pass at the moment and we can go back to binder pass
|
|
30
|
-
# in the future if we needed it.
|
|
31
|
-
_BUILTINS_STUB_FILE_PATH = os.path.join(
|
|
32
|
-
os.path.dirname(__file__),
|
|
33
|
-
"../../../vendor/typeshed/stdlib/builtins.pyi",
|
|
34
|
-
)
|
|
35
|
-
|
|
36
|
-
# Cache the builtins module once it parsed.
|
|
37
|
-
_BUILTINS_MODULE: uni.Module | None = None
|
|
38
|
-
|
|
39
20
|
def before_pass(self) -> None:
|
|
40
21
|
"""Initialize the checker pass."""
|
|
41
|
-
self.
|
|
22
|
+
self.evaluator = self.prog.get_type_evaluator()
|
|
23
|
+
self.evaluator.diagnostic_callback = self._add_diagnostic
|
|
42
24
|
self._insert_builtin_symbols()
|
|
43
25
|
|
|
44
|
-
|
|
45
|
-
|
|
46
|
-
|
|
47
|
-
|
|
48
|
-
|
|
26
|
+
def _add_diagnostic(self, node: uni.UniNode, message: str, warning: bool) -> None:
|
|
27
|
+
"""Add a diagnostic message to the pass."""
|
|
28
|
+
if warning:
|
|
29
|
+
self.log_warning(message, node)
|
|
30
|
+
else:
|
|
31
|
+
self.log_error(message, node)
|
|
49
32
|
|
|
50
33
|
# --------------------------------------------------------------------------
|
|
51
34
|
# Internal helper functions
|
|
52
35
|
# --------------------------------------------------------------------------
|
|
53
36
|
|
|
54
|
-
def _binding_builtins(self) -> bool:
|
|
55
|
-
"""Return true if we're binding the builtins stub file."""
|
|
56
|
-
return self.ir_in == TypeCheckPass._BUILTINS_MODULE
|
|
57
|
-
|
|
58
|
-
def _load_builtins_stub_module(self) -> None:
|
|
59
|
-
"""Return the builtins stub module.
|
|
60
|
-
|
|
61
|
-
This will parse and cache the stub file and return the cached module on
|
|
62
|
-
subsequent calls.
|
|
63
|
-
"""
|
|
64
|
-
if self._binding_builtins() or TypeCheckPass._BUILTINS_MODULE is not None:
|
|
65
|
-
return
|
|
66
|
-
|
|
67
|
-
if not os.path.exists(TypeCheckPass._BUILTINS_STUB_FILE_PATH):
|
|
68
|
-
raise FileNotFoundError(
|
|
69
|
-
f"Builtins stub file not found at {TypeCheckPass._BUILTINS_STUB_FILE_PATH}"
|
|
70
|
-
)
|
|
71
|
-
|
|
72
|
-
file_content = read_file_with_encoding(TypeCheckPass._BUILTINS_STUB_FILE_PATH)
|
|
73
|
-
uni_source = uni.Source(file_content, TypeCheckPass._BUILTINS_STUB_FILE_PATH)
|
|
74
|
-
mod = PyastBuildPass(
|
|
75
|
-
ir_in=uni.PythonModuleAst(
|
|
76
|
-
py_ast.parse(file_content),
|
|
77
|
-
orig_src=uni_source,
|
|
78
|
-
),
|
|
79
|
-
prog=self.prog,
|
|
80
|
-
).ir_out
|
|
81
|
-
SymTabBuildPass(ir_in=mod, prog=self.prog)
|
|
82
|
-
TypeCheckPass._BUILTINS_MODULE = mod
|
|
83
|
-
|
|
84
37
|
def _insert_builtin_symbols(self) -> None:
|
|
85
|
-
|
|
38
|
+
# Don't insert builtin symbols into the builtin module itself.
|
|
39
|
+
if self.ir_in == self.evaluator.builtins_module:
|
|
86
40
|
return
|
|
87
41
|
|
|
88
42
|
# TODO: Insert these symbols.
|
|
@@ -92,20 +46,34 @@ class TypeCheckPass(UniPass):
|
|
|
92
46
|
# '__name__', '__loader__', '__package__', '__spec__', '__path__',
|
|
93
47
|
# '__file__', '__cached__', '__dict__', '__annotations__',
|
|
94
48
|
# '__builtins__', '__doc__',
|
|
95
|
-
assert (
|
|
96
|
-
TypeCheckPass._BUILTINS_MODULE is not None
|
|
97
|
-
), "Builtins module is not loaded"
|
|
98
49
|
if self.ir_in.parent_scope is not None:
|
|
99
50
|
self.log_info("Builtins module is already bound, skipping.")
|
|
100
51
|
return
|
|
101
52
|
# Review: If we ever assume a module cannot have a parent scope, this will
|
|
102
53
|
# break that contract.
|
|
103
|
-
self.ir_in.parent_scope =
|
|
54
|
+
self.ir_in.parent_scope = self.evaluator.builtins_module
|
|
104
55
|
|
|
105
56
|
# --------------------------------------------------------------------------
|
|
106
57
|
# Ast walker hooks
|
|
107
58
|
# --------------------------------------------------------------------------
|
|
108
59
|
|
|
60
|
+
def enter_ability(self, node: uni.Ability) -> None:
|
|
61
|
+
"""Enter an ability node."""
|
|
62
|
+
# If the node has a @staticmethod decorator, mark it as a static method.
|
|
63
|
+
# This is needed since the AST raised from Python does not have this info.
|
|
64
|
+
for decor in node.decorators or []:
|
|
65
|
+
ty = self.evaluator.get_type_of_expression(decor)
|
|
66
|
+
if isinstance(ty, jtypes.ClassType) and ty.is_builtin("staticmethod"):
|
|
67
|
+
node.is_static = True
|
|
68
|
+
break
|
|
69
|
+
|
|
70
|
+
def exit_import(self, node: uni.Import) -> None:
|
|
71
|
+
"""Exit an import node."""
|
|
72
|
+
if node.from_loc:
|
|
73
|
+
for item in node.items:
|
|
74
|
+
if isinstance(item, uni.ModuleItem):
|
|
75
|
+
self.evaluator.get_type_of_module_item(item)
|
|
76
|
+
|
|
109
77
|
def exit_assignment(self, node: uni.Assignment) -> None:
|
|
110
78
|
"""Pyright: Checker.visitAssignment(node: AssignmentNode): boolean."""
|
|
111
79
|
# TODO: In pyright this logic is present at evaluateTypesForAssignmentStatement
|
|
@@ -126,3 +94,10 @@ class TypeCheckPass(UniPass):
|
|
|
126
94
|
def exit_atom_trailer(self, node: uni.AtomTrailer) -> None:
|
|
127
95
|
"""Handle the atom trailer node."""
|
|
128
96
|
self.evaluator.get_type_of_expression(node)
|
|
97
|
+
|
|
98
|
+
def exit_func_call(self, node: uni.FuncCall) -> None:
|
|
99
|
+
"""Handle the function call node."""
|
|
100
|
+
# TODO:
|
|
101
|
+
# 1. Function Existence & Callable Validation
|
|
102
|
+
# 2. Argument Matching(count, types, names)
|
|
103
|
+
self.evaluator.get_type_of_expression(node)
|