jaclang 0.8.7__py3-none-any.whl → 0.8.9__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of jaclang might be problematic. Click here for more details.
- jaclang/cli/cli.py +77 -29
- jaclang/cli/cmdreg.py +44 -0
- jaclang/compiler/constant.py +6 -2
- jaclang/compiler/jac.lark +37 -47
- jaclang/compiler/larkparse/jac_parser.py +2 -2
- jaclang/compiler/parser.py +356 -61
- jaclang/compiler/passes/main/__init__.py +2 -4
- jaclang/compiler/passes/main/def_use_pass.py +1 -4
- jaclang/compiler/passes/main/predynamo_pass.py +221 -0
- jaclang/compiler/passes/main/pyast_gen_pass.py +221 -135
- jaclang/compiler/passes/main/pyast_load_pass.py +54 -20
- jaclang/compiler/passes/main/sym_tab_build_pass.py +1 -1
- jaclang/compiler/passes/main/tests/fixtures/checker/import_sym.jac +2 -0
- jaclang/compiler/passes/main/tests/fixtures/checker/import_sym_test.jac +6 -0
- jaclang/compiler/passes/main/tests/fixtures/checker/imported_sym.jac +5 -0
- jaclang/compiler/passes/main/tests/fixtures/checker_arg_param_match.jac +37 -0
- jaclang/compiler/passes/main/tests/fixtures/checker_arity.jac +18 -0
- jaclang/compiler/passes/main/tests/fixtures/checker_cat_is_animal.jac +18 -0
- jaclang/compiler/passes/main/tests/fixtures/checker_float.jac +7 -0
- jaclang/compiler/passes/main/tests/fixtures/checker_param_types.jac +11 -0
- jaclang/compiler/passes/main/tests/fixtures/checker_self_type.jac +9 -0
- jaclang/compiler/passes/main/tests/fixtures/checker_sym_inherit.jac +42 -0
- jaclang/compiler/passes/main/tests/fixtures/predynamo_fix3.jac +43 -0
- jaclang/compiler/passes/main/tests/fixtures/predynamo_where_assign.jac +13 -0
- jaclang/compiler/passes/main/tests/fixtures/predynamo_where_return.jac +11 -0
- jaclang/compiler/passes/main/tests/test_checker_pass.py +190 -0
- jaclang/compiler/passes/main/tests/test_predynamo_pass.py +56 -0
- jaclang/compiler/passes/main/type_checker_pass.py +29 -73
- jaclang/compiler/passes/tool/doc_ir_gen_pass.py +302 -58
- jaclang/compiler/passes/tool/jac_formatter_pass.py +119 -69
- jaclang/compiler/passes/tool/tests/fixtures/corelib_fmt.jac +3 -3
- jaclang/compiler/passes/tool/tests/fixtures/general_format_checks/triple_quoted_string.jac +4 -5
- jaclang/compiler/passes/tool/tests/fixtures/import_fmt.jac +7 -1
- jaclang/compiler/passes/tool/tests/fixtures/tagbreak.jac +276 -10
- jaclang/compiler/passes/transform.py +12 -8
- jaclang/compiler/program.py +19 -7
- jaclang/compiler/tests/fixtures/jac_import_py_files.py +4 -0
- jaclang/compiler/tests/fixtures/jac_module.jac +3 -0
- jaclang/compiler/tests/fixtures/multiple_syntax_errors.jac +10 -0
- jaclang/compiler/tests/fixtures/python_module.py +1 -0
- jaclang/compiler/tests/test_importer.py +39 -0
- jaclang/compiler/tests/test_parser.py +49 -0
- jaclang/compiler/type_system/type_evaluator.jac +959 -0
- jaclang/compiler/type_system/type_utils.py +246 -0
- jaclang/compiler/type_system/types.py +58 -2
- jaclang/compiler/unitree.py +102 -107
- jaclang/langserve/engine.jac +138 -159
- jaclang/langserve/server.jac +25 -1
- jaclang/langserve/tests/fixtures/circle.jac +3 -3
- jaclang/langserve/tests/fixtures/circle_err.jac +3 -3
- jaclang/langserve/tests/fixtures/circle_pure.test.jac +3 -3
- jaclang/langserve/tests/fixtures/completion_test_err.jac +10 -0
- jaclang/langserve/tests/server_test/circle_template.jac +80 -0
- jaclang/langserve/tests/server_test/glob_template.jac +4 -0
- jaclang/langserve/tests/server_test/test_lang_serve.py +154 -309
- jaclang/langserve/tests/server_test/utils.py +153 -116
- jaclang/langserve/tests/test_server.py +21 -84
- jaclang/langserve/utils.jac +12 -15
- jaclang/lib.py +17 -0
- jaclang/runtimelib/archetype.py +25 -25
- jaclang/runtimelib/constructs.py +2 -2
- jaclang/runtimelib/machine.py +63 -46
- jaclang/runtimelib/meta_importer.py +27 -1
- jaclang/runtimelib/tests/fixtures/custom_access_validation.jac +1 -1
- jaclang/runtimelib/tests/fixtures/savable_object.jac +2 -2
- jaclang/settings.py +19 -16
- jaclang/tests/fixtures/abc_check.jac +3 -3
- jaclang/tests/fixtures/arch_rel_import_creation.jac +12 -12
- jaclang/tests/fixtures/attr_pattern_case.jac +18 -0
- jaclang/tests/fixtures/chandra_bugs2.jac +3 -3
- jaclang/tests/fixtures/create_dynamic_archetype.jac +13 -13
- jaclang/tests/fixtures/funccall_genexpr.jac +7 -0
- jaclang/tests/fixtures/funccall_genexpr.py +5 -0
- jaclang/tests/fixtures/maxfail_run_test.jac +4 -4
- jaclang/tests/fixtures/params/param_syntax_err.jac +9 -0
- jaclang/tests/fixtures/params/test_complex_params.jac +42 -0
- jaclang/tests/fixtures/params/test_failing_kwonly.jac +207 -0
- jaclang/tests/fixtures/params/test_failing_posonly.jac +116 -0
- jaclang/tests/fixtures/params/test_failing_varargs.jac +300 -0
- jaclang/tests/fixtures/params/test_kwonly_params.jac +29 -0
- jaclang/tests/fixtures/py2jac_params.py +8 -0
- jaclang/tests/fixtures/run_test.jac +4 -4
- jaclang/tests/test_cli.py +159 -7
- jaclang/tests/test_language.py +213 -38
- jaclang/tests/test_reference.py +3 -1
- jaclang/utils/helpers.py +67 -6
- jaclang/utils/module_resolver.py +10 -0
- jaclang/utils/test.py +8 -0
- jaclang/utils/tests/test_lang_tools.py +4 -15
- jaclang/utils/treeprinter.py +0 -18
- {jaclang-0.8.7.dist-info → jaclang-0.8.9.dist-info}/METADATA +1 -2
- {jaclang-0.8.7.dist-info → jaclang-0.8.9.dist-info}/RECORD +95 -65
- {jaclang-0.8.7.dist-info → jaclang-0.8.9.dist-info}/WHEEL +1 -1
- jaclang/compiler/passes/main/inheritance_pass.py +0 -131
- jaclang/compiler/type_system/type_evaluator.py +0 -560
- jaclang/langserve/dev_engine.jac +0 -645
- jaclang/langserve/dev_server.jac +0 -201
- /jaclang/{langserve/tests/server_test/code_test.py → tests/fixtures/py2jac_empty.py} +0 -0
- {jaclang-0.8.7.dist-info → jaclang-0.8.9.dist-info}/entry_points.txt +0 -0
|
@@ -16,7 +16,7 @@ from __future__ import annotations
|
|
|
16
16
|
|
|
17
17
|
import ast as py_ast
|
|
18
18
|
import os
|
|
19
|
-
from typing import Optional, Sequence, TYPE_CHECKING, TypeAlias, TypeVar
|
|
19
|
+
from typing import Optional, Sequence, TYPE_CHECKING, TypeAlias, TypeVar, cast
|
|
20
20
|
|
|
21
21
|
import jaclang.compiler.unitree as uni
|
|
22
22
|
from jaclang.compiler.constant import Tokens as Tok
|
|
@@ -96,6 +96,8 @@ class PyastBuildPass(Transform[uni.PythonModuleAst, uni.Module]):
|
|
|
96
96
|
body: list[stmt]
|
|
97
97
|
type_ignores: list[TypeIgnore]
|
|
98
98
|
"""
|
|
99
|
+
if not node.body:
|
|
100
|
+
return uni.Module.make_stub(inject_src=self.ir_in)
|
|
99
101
|
elements: list[uni.UniNode] = [self.convert(i) for i in node.body]
|
|
100
102
|
elements[0] = (
|
|
101
103
|
elements[0].expr
|
|
@@ -2091,6 +2093,7 @@ class PyastBuildPass(Transform[uni.PythonModuleAst, uni.Module]):
|
|
|
2091
2093
|
"""Process python node.
|
|
2092
2094
|
|
|
2093
2095
|
class arguments(AST):
|
|
2096
|
+
posonlyargs: list[arg]
|
|
2094
2097
|
args: list[arg]
|
|
2095
2098
|
vararg: arg | None
|
|
2096
2099
|
kwonlyargs: list[arg]
|
|
@@ -2098,9 +2101,23 @@ class PyastBuildPass(Transform[uni.PythonModuleAst, uni.Module]):
|
|
|
2098
2101
|
kwarg: arg | None
|
|
2099
2102
|
defaults: list[expr]
|
|
2100
2103
|
"""
|
|
2101
|
-
|
|
2104
|
+
|
|
2105
|
+
def _apply_kind(params: list, kind: uni.ParamKind) -> list:
|
|
2106
|
+
for param in params:
|
|
2107
|
+
cast(uni.ParamVar, param).param_kind = kind
|
|
2108
|
+
return params
|
|
2109
|
+
|
|
2110
|
+
posonlyargs = _apply_kind(
|
|
2111
|
+
[self.convert(arg) for arg in node.posonlyargs], uni.ParamKind.POSONLY
|
|
2112
|
+
)
|
|
2113
|
+
args = _apply_kind(
|
|
2114
|
+
[self.convert(arg) for arg in node.args], uni.ParamKind.NORMAL
|
|
2115
|
+
)
|
|
2116
|
+
|
|
2102
2117
|
vararg = self.convert(node.vararg) if node.vararg else None
|
|
2118
|
+
|
|
2103
2119
|
if vararg and isinstance(vararg, uni.ParamVar):
|
|
2120
|
+
vararg.param_kind = uni.ParamKind.VARARG
|
|
2104
2121
|
vararg.unpack = uni.Token(
|
|
2105
2122
|
orig_src=self.orig_src,
|
|
2106
2123
|
name=Tok.STAR_MUL,
|
|
@@ -2113,7 +2130,10 @@ class PyastBuildPass(Transform[uni.PythonModuleAst, uni.Module]):
|
|
|
2113
2130
|
pos_end=0,
|
|
2114
2131
|
)
|
|
2115
2132
|
vararg.add_kids_left([vararg.unpack])
|
|
2116
|
-
|
|
2133
|
+
|
|
2134
|
+
kwonlyargs = _apply_kind(
|
|
2135
|
+
[self.convert(arg) for arg in node.kwonlyargs], uni.ParamKind.KWONLY
|
|
2136
|
+
)
|
|
2117
2137
|
for i in range(len(kwonlyargs)):
|
|
2118
2138
|
kwa = kwonlyargs[i]
|
|
2119
2139
|
kwd = node.kw_defaults[i]
|
|
@@ -2127,6 +2147,7 @@ class PyastBuildPass(Transform[uni.PythonModuleAst, uni.Module]):
|
|
|
2127
2147
|
kwa.add_kids_right([kwa.value])
|
|
2128
2148
|
kwarg = self.convert(node.kwarg) if node.kwarg else None
|
|
2129
2149
|
if kwarg and isinstance(kwarg, uni.ParamVar):
|
|
2150
|
+
kwarg.param_kind = uni.ParamKind.KWARG
|
|
2130
2151
|
kwarg.unpack = uni.Token(
|
|
2131
2152
|
orig_src=self.orig_src,
|
|
2132
2153
|
name=Tok.STAR_POW,
|
|
@@ -2140,29 +2161,42 @@ class PyastBuildPass(Transform[uni.PythonModuleAst, uni.Module]):
|
|
|
2140
2161
|
)
|
|
2141
2162
|
kwarg.add_kids_left([kwarg.unpack])
|
|
2142
2163
|
defaults = [self.convert(expr) for expr in node.defaults]
|
|
2143
|
-
|
|
2144
|
-
for
|
|
2145
|
-
if
|
|
2146
|
-
|
|
2147
|
-
|
|
2148
|
-
|
|
2149
|
-
|
|
2150
|
-
|
|
2151
|
-
|
|
2152
|
-
|
|
2153
|
-
|
|
2154
|
-
|
|
2155
|
-
|
|
2156
|
-
if
|
|
2157
|
-
|
|
2164
|
+
# iterate reverse to match from the end
|
|
2165
|
+
for para in [*posonlyargs, *args][::-1]:
|
|
2166
|
+
if not defaults:
|
|
2167
|
+
break
|
|
2168
|
+
default = defaults.pop()
|
|
2169
|
+
if (
|
|
2170
|
+
default
|
|
2171
|
+
and isinstance(para, uni.ParamVar)
|
|
2172
|
+
and isinstance(default, uni.Expr)
|
|
2173
|
+
):
|
|
2174
|
+
para.value = default
|
|
2175
|
+
para.add_kids_right([para.value])
|
|
2176
|
+
|
|
2177
|
+
if kwonlyargs or args or posonlyargs or vararg or kwarg:
|
|
2178
|
+
kids = []
|
|
2179
|
+
kids.extend(posonlyargs) if posonlyargs else None
|
|
2180
|
+
kids.extend(args) if args else None
|
|
2181
|
+
kids.append(vararg) if vararg else None
|
|
2182
|
+
kids.extend(kwonlyargs) if kwonlyargs else None
|
|
2183
|
+
kids.append(kwarg) if kwarg else None
|
|
2158
2184
|
return uni.FuncSignature(
|
|
2159
|
-
|
|
2185
|
+
posonly_params=posonlyargs,
|
|
2186
|
+
params=args,
|
|
2187
|
+
varargs=vararg,
|
|
2188
|
+
kwonlyargs=kwonlyargs,
|
|
2189
|
+
kwargs=kwarg,
|
|
2160
2190
|
return_type=None,
|
|
2161
|
-
kid=
|
|
2191
|
+
kid=kids,
|
|
2162
2192
|
)
|
|
2163
2193
|
else:
|
|
2164
2194
|
return uni.FuncSignature(
|
|
2195
|
+
posonly_params=posonlyargs,
|
|
2165
2196
|
params=[],
|
|
2197
|
+
varargs=vararg,
|
|
2198
|
+
kwonlyargs=kwonlyargs,
|
|
2199
|
+
kwargs=kwarg,
|
|
2166
2200
|
return_type=None,
|
|
2167
2201
|
kid=[self.operator(Tok.LPAREN, "("), self.operator(Tok.RPAREN, ")")],
|
|
2168
2202
|
)
|
|
@@ -75,7 +75,7 @@ class SymTabBuildPass(UniPass):
|
|
|
75
75
|
def exit_module_path(self, node: uni.ModulePath) -> None:
|
|
76
76
|
if node.alias:
|
|
77
77
|
node.alias.sym_tab.def_insert(node.alias, single_decl="import")
|
|
78
|
-
elif node.path
|
|
78
|
+
elif node.path:
|
|
79
79
|
if node.parent_of_type(uni.Import) and not (
|
|
80
80
|
node.parent_of_type(uni.Import).from_loc
|
|
81
81
|
and node.parent_of_type(uni.Import).is_jac
|
|
@@ -0,0 +1,37 @@
|
|
|
1
|
+
|
|
2
|
+
class Foo {
|
|
3
|
+
def bar(self: Foo, a:int) -> None {}
|
|
4
|
+
static def baz(self: int, a: int) -> None {}
|
|
5
|
+
}
|
|
6
|
+
|
|
7
|
+
def foo(a: int, b: int, /, c:int, d:int, *args:int, e:int, f:int, **kwargs:int) -> None {}
|
|
8
|
+
def bar(a: int, b: int, /, c:int, d:int, e:int, f:int) -> None {}
|
|
9
|
+
def baz(a: int, *, b:int) -> None {}
|
|
10
|
+
|
|
11
|
+
with entry {
|
|
12
|
+
f = Foo();
|
|
13
|
+
|
|
14
|
+
f.bar();
|
|
15
|
+
f.bar(1);
|
|
16
|
+
f.bar(1, 2);
|
|
17
|
+
|
|
18
|
+
f.baz();
|
|
19
|
+
f.baz(1);
|
|
20
|
+
f.baz(1, 2);
|
|
21
|
+
|
|
22
|
+
foo(1, 2, 3, d=4, e=5, f=6, g=7, h=8); # c is positional and d is named
|
|
23
|
+
foo(1, 2, 3, 4, 5, 6, 7, 8, e=5, f=6); # marching extra with *args
|
|
24
|
+
foo(1, 2, d=3, e=4, f=5, c=4); # order does not matter for named
|
|
25
|
+
|
|
26
|
+
foo(1, 2, 3, d=4, e=5, g=7, h=8); # missing argument 'f'
|
|
27
|
+
foo(1, b=2, c=3, d=4, e=5, f=6); # b is positional only
|
|
28
|
+
|
|
29
|
+
bar(1, 2, 3, 4, 5, f=6);
|
|
30
|
+
bar(1, 2, 3, 4, 5, 6, 7, 8, 9); # too many args
|
|
31
|
+
bar(1, 2, 3, 4, 5, 6, c=3); # already matched
|
|
32
|
+
bar(1, 2, 3, 4, 5, 6, h=1); # h is not matched
|
|
33
|
+
|
|
34
|
+
baz(a=1, b=2);
|
|
35
|
+
baz(1, b=2); # a can be both positional and keyword
|
|
36
|
+
baz(1, 2); # 'b' can only be keyword arg
|
|
37
|
+
}
|
|
@@ -0,0 +1,18 @@
|
|
|
1
|
+
|
|
2
|
+
class Foo {
|
|
3
|
+
def first_is_self(self: Foo) -> None {}
|
|
4
|
+
|
|
5
|
+
def with_default_args(self: Foo, a:int, b:int=42) -> None {}
|
|
6
|
+
}
|
|
7
|
+
|
|
8
|
+
|
|
9
|
+
with entry {
|
|
10
|
+
f = Foo();
|
|
11
|
+
f.first_is_self(); # <-- Ok
|
|
12
|
+
f.first_is_self(f); # <-- Error
|
|
13
|
+
|
|
14
|
+
f.with_default_args(1); # <-- Ok
|
|
15
|
+
f.with_default_args(1, 2); # <-- Ok
|
|
16
|
+
f.with_default_args(1, 2, 3); # <-- Error
|
|
17
|
+
f.with_default_args(); # <-- Error
|
|
18
|
+
}
|
|
@@ -0,0 +1,18 @@
|
|
|
1
|
+
|
|
2
|
+
obj Animal {}
|
|
3
|
+
obj Cat(Animal) {}
|
|
4
|
+
obj Lion(Cat) {}
|
|
5
|
+
|
|
6
|
+
obj NotAnimal {}
|
|
7
|
+
|
|
8
|
+
def animal_func(a: Animal) -> None {}
|
|
9
|
+
|
|
10
|
+
with entry {
|
|
11
|
+
cat: Cat = Cat();
|
|
12
|
+
lion: Lion = Lion();
|
|
13
|
+
not_animal: NotAnimal = NotAnimal();
|
|
14
|
+
|
|
15
|
+
animal_func(cat); # <-- Ok
|
|
16
|
+
animal_func(lion); # <-- Ok
|
|
17
|
+
animal_func(not_animal); # <-- Error
|
|
18
|
+
}
|
|
@@ -0,0 +1,42 @@
|
|
|
1
|
+
|
|
2
|
+
# -----------------------------------------------------------------------------
|
|
3
|
+
# Simple Inheritance
|
|
4
|
+
# -----------------------------------------------------------------------------
|
|
5
|
+
|
|
6
|
+
node Parent {
|
|
7
|
+
has val: int = 42;
|
|
8
|
+
}
|
|
9
|
+
|
|
10
|
+
node Child (Parent) {
|
|
11
|
+
|
|
12
|
+
}
|
|
13
|
+
|
|
14
|
+
|
|
15
|
+
with entry {
|
|
16
|
+
c = Child();
|
|
17
|
+
c.val = 42; # <-- Ok
|
|
18
|
+
c.val = "str"; # <-- Error
|
|
19
|
+
}
|
|
20
|
+
|
|
21
|
+
|
|
22
|
+
# -----------------------------------------------------------------------------
|
|
23
|
+
# A Complex Inheritance
|
|
24
|
+
# -----------------------------------------------------------------------------
|
|
25
|
+
|
|
26
|
+
node Animal {
|
|
27
|
+
has name: str;
|
|
28
|
+
}
|
|
29
|
+
|
|
30
|
+
node Cat(Animal) {
|
|
31
|
+
}
|
|
32
|
+
|
|
33
|
+
node Lion(Cat) {
|
|
34
|
+
}
|
|
35
|
+
|
|
36
|
+
|
|
37
|
+
with entry {
|
|
38
|
+
l = Lion();
|
|
39
|
+
|
|
40
|
+
l.name = "Simba"; # <-- Ok
|
|
41
|
+
l.name = 42; # <-- Error
|
|
42
|
+
}
|
|
@@ -0,0 +1,43 @@
|
|
|
1
|
+
import torch;
|
|
2
|
+
class Cfg {
|
|
3
|
+
def __init__(self: Cfg, max_position_embeddings: Any = 8, has_original: Any = False) {
|
|
4
|
+
self.max_position_embeddings = max_position_embeddings;
|
|
5
|
+
if has_original {
|
|
6
|
+
self.original_max_position_embeddings = (max_position_embeddings // 2);
|
|
7
|
+
}
|
|
8
|
+
}
|
|
9
|
+
}
|
|
10
|
+
class ToyModel(torch.nn.Module) {
|
|
11
|
+
def __init__(self: ToyModel, cfg: Cfg) {
|
|
12
|
+
super.init();
|
|
13
|
+
self.config = cfg;
|
|
14
|
+
self.long_inv_freq = torch.tensor([10.0, 20.0, 30.0]);
|
|
15
|
+
self.original_inv_freq = torch.tensor([1.0, 2.0, 3.0]);
|
|
16
|
+
}
|
|
17
|
+
def rope_init_fn(self: ToyModel, cfg: Any, device: Any, seq_len: int) {
|
|
18
|
+
return (torch.arange(1, 4, dtype=torch.float32, device=device), None);
|
|
19
|
+
}
|
|
20
|
+
def _longrope_frequency_update(
|
|
21
|
+
self: ToyModel,
|
|
22
|
+
position_ids: Any,
|
|
23
|
+
device: Any = 'cpu'
|
|
24
|
+
) {
|
|
25
|
+
seq_len = (torch.max(position_ids) + 1);
|
|
26
|
+
if hasattr(self.config, 'original_max_position_embeddings') {
|
|
27
|
+
original_max_position_embeddings = self.config.original_max_position_embeddings;
|
|
28
|
+
} else {
|
|
29
|
+
original_max_position_embeddings = self.config.max_position_embeddings;
|
|
30
|
+
}
|
|
31
|
+
if (seq_len > original_max_position_embeddings) {
|
|
32
|
+
self.register_buffer('inv_freq', self.long_inv_freq, persistent=False);
|
|
33
|
+
} else {
|
|
34
|
+
self.register_buffer(
|
|
35
|
+
'inv_freq',
|
|
36
|
+
self.original_inv_freq.to(device),
|
|
37
|
+
persistent=False
|
|
38
|
+
);
|
|
39
|
+
}
|
|
40
|
+
return self.inv_freq;
|
|
41
|
+
}
|
|
42
|
+
}
|
|
43
|
+
|
|
@@ -28,6 +28,18 @@ class TypeCheckerPassTests(TestCase):
|
|
|
28
28
|
^^^^^^^^^^^^^^^^^^^^^^
|
|
29
29
|
""", program.errors_had[1].pretty_print())
|
|
30
30
|
|
|
31
|
+
def test_float_types(self) -> None:
|
|
32
|
+
program = JacProgram()
|
|
33
|
+
mod = program.compile(self.fixture_abs_path("checker_float.jac"))
|
|
34
|
+
TypeCheckPass(ir_in=mod, prog=program)
|
|
35
|
+
self.assertEqual(len(program.errors_had), 1)
|
|
36
|
+
self._assert_error_pretty_found("""
|
|
37
|
+
f: float = pi; # <-- OK
|
|
38
|
+
s: str = pi; # <-- Error
|
|
39
|
+
^^^^^^^^^^^
|
|
40
|
+
""", program.errors_had[0].pretty_print())
|
|
41
|
+
|
|
42
|
+
|
|
31
43
|
def test_infer_type_of_assignment(self) -> None:
|
|
32
44
|
program = JacProgram()
|
|
33
45
|
mod = program.compile(self.fixture_abs_path("infer_type_assignment.jac"))
|
|
@@ -49,6 +61,17 @@ class TypeCheckerPassTests(TestCase):
|
|
|
49
61
|
^^^^^^^^^^^^^^^^^^^
|
|
50
62
|
""", program.errors_had[0].pretty_print())
|
|
51
63
|
|
|
64
|
+
def test_imported_sym(self) -> None:
|
|
65
|
+
program = JacProgram()
|
|
66
|
+
mod = program.compile(self.fixture_abs_path("checker/import_sym_test.jac"))
|
|
67
|
+
TypeCheckPass(ir_in=mod, prog=program)
|
|
68
|
+
self.assertEqual(len(program.errors_had), 1)
|
|
69
|
+
self._assert_error_pretty_found("""
|
|
70
|
+
a: str = foo(); # <-- Ok
|
|
71
|
+
b: int = foo(); # <-- Error
|
|
72
|
+
^^^^^^^^^^^^^^
|
|
73
|
+
""", program.errors_had[0].pretty_print())
|
|
74
|
+
|
|
52
75
|
def test_member_access_type_infered(self) -> None:
|
|
53
76
|
program = JacProgram()
|
|
54
77
|
mod = program.compile(self.fixture_abs_path("member_access_type_inferred.jac"))
|
|
@@ -59,6 +82,22 @@ class TypeCheckerPassTests(TestCase):
|
|
|
59
82
|
^^^^^^^^^
|
|
60
83
|
""", program.errors_had[0].pretty_print())
|
|
61
84
|
|
|
85
|
+
def test_inherited_symbol(self) -> None:
|
|
86
|
+
program = JacProgram()
|
|
87
|
+
mod = program.compile(self.fixture_abs_path("checker_sym_inherit.jac"))
|
|
88
|
+
TypeCheckPass(ir_in=mod, prog=program)
|
|
89
|
+
self.assertEqual(len(program.errors_had), 2)
|
|
90
|
+
self._assert_error_pretty_found("""
|
|
91
|
+
c.val = 42; # <-- Ok
|
|
92
|
+
c.val = "str"; # <-- Error
|
|
93
|
+
^^^^^^^^^^^^^
|
|
94
|
+
""", program.errors_had[0].pretty_print())
|
|
95
|
+
self._assert_error_pretty_found("""
|
|
96
|
+
l.name = "Simba"; # <-- Ok
|
|
97
|
+
l.name = 42; # <-- Error
|
|
98
|
+
^^^^^^^^^^^
|
|
99
|
+
""", program.errors_had[1].pretty_print())
|
|
100
|
+
|
|
62
101
|
def test_import_symbol_type_infer(self) -> None:
|
|
63
102
|
program = JacProgram()
|
|
64
103
|
mod = program.compile(self.fixture_abs_path("import_symbol_type_infer.jac"))
|
|
@@ -104,6 +143,144 @@ class TypeCheckerPassTests(TestCase):
|
|
|
104
143
|
^^^^^^^^^^^^^^^^
|
|
105
144
|
""", program.errors_had[0].pretty_print())
|
|
106
145
|
|
|
146
|
+
def test_arity(self) -> None:
|
|
147
|
+
path = self.fixture_abs_path("checker_arity.jac")
|
|
148
|
+
program = JacProgram()
|
|
149
|
+
mod = program.compile(path)
|
|
150
|
+
TypeCheckPass(ir_in=mod, prog=program)
|
|
151
|
+
self.assertEqual(len(program.errors_had), 3)
|
|
152
|
+
self._assert_error_pretty_found("""
|
|
153
|
+
f.first_is_self(f); # <-- Error
|
|
154
|
+
^
|
|
155
|
+
""", program.errors_had[0].pretty_print())
|
|
156
|
+
self._assert_error_pretty_found("""
|
|
157
|
+
f.with_default_args(1, 2, 3); # <-- Error
|
|
158
|
+
^
|
|
159
|
+
""", program.errors_had[1].pretty_print())
|
|
160
|
+
self._assert_error_pretty_found("""
|
|
161
|
+
f.with_default_args(); # <-- Error
|
|
162
|
+
^^^^^^^^^^^^^^^^^^^^^
|
|
163
|
+
""", program.errors_had[2].pretty_print())
|
|
164
|
+
|
|
165
|
+
def test_param_types(self) -> None:
|
|
166
|
+
path = self.fixture_abs_path("checker_param_types.jac")
|
|
167
|
+
program = JacProgram()
|
|
168
|
+
mod = program.compile(path)
|
|
169
|
+
TypeCheckPass(ir_in=mod, prog=program)
|
|
170
|
+
self.assertEqual(len(program.errors_had), 1)
|
|
171
|
+
self._assert_error_pretty_found("""
|
|
172
|
+
foo(A()); # <-- Ok
|
|
173
|
+
foo(B()); # <-- Error
|
|
174
|
+
^^^
|
|
175
|
+
""", program.errors_had[0].pretty_print())
|
|
176
|
+
|
|
177
|
+
def test_param_arg_match(self) -> None:
|
|
178
|
+
program = JacProgram()
|
|
179
|
+
path = self.fixture_abs_path("checker_arg_param_match.jac")
|
|
180
|
+
mod = program.compile(path)
|
|
181
|
+
TypeCheckPass(ir_in=mod, prog=program)
|
|
182
|
+
self.assertEqual(len(program.errors_had), 13)
|
|
183
|
+
|
|
184
|
+
expected_errors = [
|
|
185
|
+
"""
|
|
186
|
+
Not all required parameters were provided in the function call: 'a'
|
|
187
|
+
f = Foo();
|
|
188
|
+
f.bar();
|
|
189
|
+
^^^^^^^
|
|
190
|
+
""",
|
|
191
|
+
"""
|
|
192
|
+
Too many positional arguments
|
|
193
|
+
f.bar();
|
|
194
|
+
f.bar(1);
|
|
195
|
+
f.bar(1, 2);
|
|
196
|
+
^
|
|
197
|
+
""",
|
|
198
|
+
"""
|
|
199
|
+
Not all required parameters were provided in the function call: 'self', 'a'
|
|
200
|
+
f.bar(1, 2);
|
|
201
|
+
f.baz();
|
|
202
|
+
^^^^^^^
|
|
203
|
+
""",
|
|
204
|
+
"""
|
|
205
|
+
Not all required parameters were provided in the function call: 'a'
|
|
206
|
+
f.baz();
|
|
207
|
+
f.baz(1);
|
|
208
|
+
^^^^^^^^
|
|
209
|
+
""",
|
|
210
|
+
"""
|
|
211
|
+
Not all required parameters were provided in the function call: 'f'
|
|
212
|
+
foo(1, 2, d=3, e=4, f=5, c=4); # order does not matter for named
|
|
213
|
+
foo(1, 2, 3, d=4, e=5, g=7, h=8); # missing argument 'f'
|
|
214
|
+
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
|
|
215
|
+
""",
|
|
216
|
+
"""
|
|
217
|
+
Positional only parameter 'b' cannot be matched with a named argument
|
|
218
|
+
foo(1, 2, 3, d=4, e=5, g=7, h=8); # missing argument 'f'
|
|
219
|
+
foo(1, b=2, c=3, d=4, e=5, f=6); # b is positional only
|
|
220
|
+
^^^
|
|
221
|
+
""",
|
|
222
|
+
"""
|
|
223
|
+
Too many positional arguments
|
|
224
|
+
bar(1, 2, 3, 4, 5, f=6);
|
|
225
|
+
bar(1, 2, 3, 4, 5, 6, 7, 8, 9); # too many args
|
|
226
|
+
^
|
|
227
|
+
""",
|
|
228
|
+
"""
|
|
229
|
+
Too many positional arguments
|
|
230
|
+
bar(1, 2, 3, 4, 5, f=6);
|
|
231
|
+
bar(1, 2, 3, 4, 5, 6, 7, 8, 9); # too many args
|
|
232
|
+
^
|
|
233
|
+
""",
|
|
234
|
+
"""
|
|
235
|
+
Too many positional arguments
|
|
236
|
+
bar(1, 2, 3, 4, 5, f=6);
|
|
237
|
+
bar(1, 2, 3, 4, 5, 6, 7, 8, 9); # too many args
|
|
238
|
+
^
|
|
239
|
+
""",
|
|
240
|
+
"""
|
|
241
|
+
Parameter 'c' already matched
|
|
242
|
+
bar(1, 2, 3, 4, 5, f=6);
|
|
243
|
+
bar(1, 2, 3, 4, 5, 6, 7, 8, 9); # too many args
|
|
244
|
+
bar(1, 2, 3, 4, 5, 6, c=3); # already matched
|
|
245
|
+
^^^
|
|
246
|
+
""",
|
|
247
|
+
"""
|
|
248
|
+
Named argument 'h' does not match any parameter
|
|
249
|
+
bar(1, 2, 3, 4, 5, 6, 7, 8, 9); # too many args
|
|
250
|
+
bar(1, 2, 3, 4, 5, 6, c=3); # already matched
|
|
251
|
+
bar(1, 2, 3, 4, 5, 6, h=1); # h is not matched
|
|
252
|
+
^^^
|
|
253
|
+
""",
|
|
254
|
+
"""
|
|
255
|
+
Too many positional arguments
|
|
256
|
+
baz(a=1, b=2);
|
|
257
|
+
baz(1, b=2); # a can be both positional and keyword
|
|
258
|
+
baz(1, 2); # 'b' can only be keyword arg
|
|
259
|
+
^
|
|
260
|
+
""",
|
|
261
|
+
"""
|
|
262
|
+
Not all required parameters were provided in the function call: 'b'
|
|
263
|
+
baz(a=1, b=2);
|
|
264
|
+
baz(1, b=2); # a can be both positional and keyword
|
|
265
|
+
baz(1, 2); # 'b' can only be keyword arg
|
|
266
|
+
^^^^^^^^^
|
|
267
|
+
""",
|
|
268
|
+
]
|
|
269
|
+
|
|
270
|
+
for i, expected in enumerate(expected_errors):
|
|
271
|
+
self._assert_error_pretty_found(expected, program.errors_had[i].pretty_print())
|
|
272
|
+
|
|
273
|
+
def test_self_type_inference(self) -> None:
|
|
274
|
+
path = self.fixture_abs_path("checker_self_type.jac")
|
|
275
|
+
program = JacProgram()
|
|
276
|
+
mod = program.compile(path)
|
|
277
|
+
TypeCheckPass(ir_in=mod, prog=program)
|
|
278
|
+
self.assertEqual(len(program.errors_had), 1)
|
|
279
|
+
self._assert_error_pretty_found("""
|
|
280
|
+
x: str = self.i; # <-- Error
|
|
281
|
+
^^^^^^^^^^^^^^^
|
|
282
|
+
""", program.errors_had[0].pretty_print())
|
|
283
|
+
|
|
107
284
|
def test_binary_op(self) -> None:
|
|
108
285
|
program = JacProgram()
|
|
109
286
|
mod = program.compile(self.fixture_abs_path("checker_binary_op.jac"))
|
|
@@ -140,6 +317,19 @@ class TypeCheckerPassTests(TestCase):
|
|
|
140
317
|
^^^^^^^^^^^^^^
|
|
141
318
|
""", program.errors_had[0].pretty_print())
|
|
142
319
|
|
|
320
|
+
def test_checker_cat_is_animal(self) -> None:
|
|
321
|
+
path = self.fixture_abs_path("checker_cat_is_animal.jac")
|
|
322
|
+
program = JacProgram()
|
|
323
|
+
mod = program.compile(path)
|
|
324
|
+
TypeCheckPass(ir_in=mod, prog=program)
|
|
325
|
+
self.assertEqual(len(program.errors_had), 1)
|
|
326
|
+
self._assert_error_pretty_found("""
|
|
327
|
+
animal_func(cat); # <-- Ok
|
|
328
|
+
animal_func(lion); # <-- Ok
|
|
329
|
+
animal_func(not_animal); # <-- Error
|
|
330
|
+
^^^^^^^^^^
|
|
331
|
+
""", program.errors_had[0].pretty_print())
|
|
332
|
+
|
|
143
333
|
def test_checker_import_missing_module(self) -> None:
|
|
144
334
|
path = self.fixture_abs_path("checker_import_missing_module.jac")
|
|
145
335
|
program = JacProgram()
|
|
@@ -0,0 +1,56 @@
|
|
|
1
|
+
"""Test ast build pass module."""
|
|
2
|
+
|
|
3
|
+
import io
|
|
4
|
+
import os
|
|
5
|
+
import sys
|
|
6
|
+
|
|
7
|
+
from jaclang.compiler.program import JacProgram, py_code_gen
|
|
8
|
+
from jaclang.utils.test import TestCase
|
|
9
|
+
from jaclang.settings import settings
|
|
10
|
+
from jaclang.compiler.passes.main import PreDynamoPass
|
|
11
|
+
|
|
12
|
+
class PreDynamoPassTests(TestCase):
|
|
13
|
+
"""Test pass module."""
|
|
14
|
+
|
|
15
|
+
TargetPass = PreDynamoPass
|
|
16
|
+
|
|
17
|
+
def setUp(self) -> None:
|
|
18
|
+
"""Set up test."""
|
|
19
|
+
settings.predynamo_pass = True
|
|
20
|
+
return super().setUp()
|
|
21
|
+
|
|
22
|
+
def tearDown(self) -> None:
|
|
23
|
+
"""Tear down test."""
|
|
24
|
+
settings.predynamo_pass = False
|
|
25
|
+
# Remove PreDynamoPass from global py_code_gen list if it was added
|
|
26
|
+
if PreDynamoPass in py_code_gen:
|
|
27
|
+
py_code_gen.remove(PreDynamoPass)
|
|
28
|
+
return super().tearDown()
|
|
29
|
+
|
|
30
|
+
def test_predynamo_where_assign(self) -> None:
|
|
31
|
+
"""Test torch.where transformation."""
|
|
32
|
+
captured_output = io.StringIO()
|
|
33
|
+
sys.stdout = captured_output
|
|
34
|
+
settings.predynamo_pass = True
|
|
35
|
+
code_gen = JacProgram().compile(self.fixture_abs_path("predynamo_where_assign.jac"))
|
|
36
|
+
sys.stdout = sys.__stdout__
|
|
37
|
+
self.assertIn("torch.where", code_gen.unparse())
|
|
38
|
+
|
|
39
|
+
def test_predynamo_where_return(self) -> None:
|
|
40
|
+
"""Test torch.where transformation."""
|
|
41
|
+
captured_output = io.StringIO()
|
|
42
|
+
sys.stdout = captured_output
|
|
43
|
+
code_gen = JacProgram().compile(self.fixture_abs_path("predynamo_where_return.jac"))
|
|
44
|
+
sys.stdout = sys.__stdout__
|
|
45
|
+
self.assertIn("torch.where", code_gen.unparse())
|
|
46
|
+
|
|
47
|
+
|
|
48
|
+
def test_predynamo_fix3(self) -> None:
|
|
49
|
+
"""Test torch.where transformation."""
|
|
50
|
+
captured_output = io.StringIO()
|
|
51
|
+
sys.stdout = captured_output
|
|
52
|
+
code_gen = JacProgram().compile(self.fixture_abs_path("predynamo_fix3.jac"))
|
|
53
|
+
sys.stdout = sys.__stdout__
|
|
54
|
+
unparsed_code = code_gen.unparse()
|
|
55
|
+
self.assertIn("__inv_freq = torch.where(", unparsed_code)
|
|
56
|
+
self.assertIn("self.register_buffer('inv_freq', __inv_freq, persistent=False);", unparsed_code)
|